diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 400284e487..77cf25d4ca 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -66,6 +66,8 @@ from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( CSPDarkNetBackbone, ) @@ -257,7 +259,17 @@ from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( + StableDiffusion3TextToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) from keras_hub.src.models.t5.t5_backbone import T5Backbone +from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.models.task import Task from keras_hub.src.models.text_classifier import TextClassifier @@ -265,6 +277,7 @@ from keras_hub.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) +from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 9a011836ed..d0edf9dabd 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -21,6 +21,7 @@ from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) diff --git a/keras_hub/src/models/stable_diffusion_v3/__init__.py b/keras_hub/src/models/clip/__init__.py similarity index 100% rename from keras_hub/src/models/stable_diffusion_v3/__init__.py rename to keras_hub/src/models/clip/__init__.py diff --git a/keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py b/keras_hub/src/models/clip/clip_encoder_block.py similarity index 89% rename from keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py rename to keras_hub/src/models/clip/clip_encoder_block.py index 6fe4bc9b59..863b705515 100644 --- a/keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +++ b/keras_hub/src/models/clip/clip_encoder_block.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
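For orientation, the export additions above make the following symbols reachable from the public namespaces once the patch is applied (a usage sketch, not part of the diff):

```python
import keras_hub

# New exports in keras_hub/api/models/__init__.py:
keras_hub.models.CLIPPreprocessor
keras_hub.models.CLIPTokenizer
keras_hub.models.StableDiffusion3Backbone
keras_hub.models.StableDiffusion3TextToImage
keras_hub.models.StableDiffusion3TextToImagePreprocessor
keras_hub.models.T5Preprocessor
keras_hub.models.TextToImage

# CLIPTokenizer is also exported in keras_hub/api/tokenizers/__init__.py:
keras_hub.tokenizers.CLIPTokenizer
```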
+from keras import dtype_policies from keras import layers from keras import ops @@ -43,7 +44,7 @@ def __init__( intermediate_activation = quick_gelu self.layer_norm_1 = layers.LayerNormalization( - epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1" + epsilon=1e-5, dtype="float32", name="layer_norm_1" ) self.attention = layers.MultiHeadAttention( num_heads, @@ -52,7 +53,7 @@ def __init__( name="attention", ) self.layer_norm_2 = layers.LayerNormalization( - epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2" + epsilon=1e-5, dtype="float32", name="layer_norm_2" ) self.dense_1 = layers.Dense( self.intermediate_dim, dtype=self.dtype_policy, name="dense_1" @@ -67,6 +68,11 @@ def __init__( def build(self, input_shape): self.layer_norm_1.build(input_shape) self.attention.build(input_shape, input_shape, input_shape) + # Before Keras 3.2, there was no setter for `dtype_policy`. Directly + # assign a `DTypePolicy` instead. + self.attention._softmax.dtype_policy = dtype_policies.DTypePolicy( + "float32" + ) self.layer_norm_2.build(input_shape) self.dense_1.build(input_shape) input_shape = self.dense_1.compute_output_shape(input_shape) diff --git a/keras_hub/src/models/clip/clip_preprocessor.py b/keras_hub/src/models/clip/clip_preprocessor.py new file mode 100644 index 0000000000..c8632e0334 --- /dev/null +++ b/keras_hub/src/models/clip/clip_preprocessor.py @@ -0,0 +1,147 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.preprocessor import Preprocessor +from keras_hub.src.utils.tensor_utils import preprocessing_function + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.CLIPPreprocessor") +class CLIPPreprocessor(Preprocessor): + """CLIP preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do 2 things: + + - Tokenize the inputs using the `tokenizer`. + - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`. + + This layer can be used directly with `tf.data.Dataset.map` to preprocess + string data in the `(x, y, sample_weight)` format used by + `keras.Model.fit`. + + The call method of this layer accepts three arguments, `x`, `y`, and + `sample_weight`. `x` can be a python string or tensor representing a single + segment, a list of python strings representing a batch of single segments, + or a list of tensors representing multiple segments to be packed together. + `y` and `sample_weight` are both optional, can have any format, and will be + passed through unaltered. + + `CLIPPreprocessor` forces the input to have only one segment, as CLIP is + mainly used for generation tasks. 
For tasks having multi-segment inputs + like "glue/mnli", please use a model designed for classification purposes + such as BERT or RoBERTa. + + Args: + tokenizer: A `keras_hub.models.CLIPTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + to_lower: bool. Whether to lower the inputs. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Any label data. Will be passed through unaltered. + sample_weight: Any label weight data. Will be passed through unaltered. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + """ + + # TODO: Add example once we have a CLIP model. + + tokenizer_cls = CLIPTokenizer + + def __init__( + self, + tokenizer, + sequence_length=77, + add_start_token=True, + add_end_token=True, + to_lower=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = None + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + self.to_lower = to_lower + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.end_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + sequence_length = sequence_length or self.sequence_length + if self.to_lower: + x = tf.strings.lower(x) + token_ids, padding_mask = self.packer( + self.tokenizer(x), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + "to_lower": self.to_lower, + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value diff --git a/keras_hub/src/models/stable_diffusion_v3/clip_preprocessor_test.py b/keras_hub/src/models/clip/clip_preprocessor_test.py similarity index 83% rename from keras_hub/src/models/stable_diffusion_v3/clip_preprocessor_test.py rename to keras_hub/src/models/clip/clip_preprocessor_test.py index 8585752b84..8321d7be75 100644 --- a/keras_hub/src/models/stable_diffusion_v3/clip_preprocessor_test.py +++ b/keras_hub/src/models/clip/clip_preprocessor_test.py @@ -13,12 +13,8 @@ # limitations under the License. 
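The class docstring above still carries a `TODO` for an example, so here is a minimal usage sketch of `CLIPPreprocessor` with a toy vocabulary and merge list (the ids and merge rules below are illustrative only, not a real CLIP vocabulary):

```python
from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer

# Toy assets for illustration; a real CLIP vocabulary has 49408 entries.
vocab = {"<|startoftext|>": 5, "<|endoftext|>": 4, "air": 1, "plane</w>": 2, "port</w>": 3}
merges = ["a i", "ai r", "p l", "pl a", "pla n", "plan e</w>", "p o", "po r", "por t</w>"]

tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges)
preprocessor = CLIPPreprocessor(tokenizer=tokenizer, sequence_length=8)

x = preprocessor([" airplane airport"])
# x["token_ids"]: start token, BPE ids, end token, then padding with the
# end token id (the packer uses `end_token_id` as its pad value).
# x["padding_mask"]: 1 for the real tokens, 0 for the padded positions.
```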
import pytest -from keras_hub.src.models.stable_diffusion_v3.clip_preprocessor import ( - CLIPPreprocessor, -) -from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import ( - CLIPTokenizer, -) +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.tests.test_case import TestCase @@ -43,7 +39,7 @@ def test_preprocessor_basics(self): input_data=self.input_data, expected_output={ "token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]], - "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], }, ) @@ -54,17 +50,16 @@ def test_no_start_end_token(self): sequence_length=8, add_start_token=False, add_end_token=False, - pad_with_end_token=False, ) x = preprocessor(input_data) - self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 4, 4, 4, 4]] * 4) self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) def test_sequence_length_override(self): input_data = " airplane airport" preprocessor = CLIPPreprocessor(**self.init_kwargs) - x = preprocessor(input_data, sequence_length=4) - self.assertAllEqual(x["token_ids"], [5, 1, 2, 1]) + x = preprocessor(input_data, sequence_length=5) + self.assertAllEqual(x["token_ids"], [5, 1, 2, 1, 4]) @pytest.mark.kaggle_key_required @pytest.mark.extra_large diff --git a/keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py b/keras_hub/src/models/clip/clip_text_encoder.py similarity index 51% rename from keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py rename to keras_hub/src/models/clip/clip_text_encoder.py index 77cfc7e98e..0376777bb6 100644 --- a/keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +++ b/keras_hub/src/models/clip/clip_text_encoder.py @@ -11,21 +11,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import keras from keras import layers -from keras import ops from keras_hub.src.layers.modeling.token_and_position_embedding import ( TokenAndPositionEmbedding, ) -from keras_hub.src.models.stable_diffusion_v3.clip_encoder_block import ( - CLIPEncoderBlock, -) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock + +class CLIPTextEncoder(Backbone): + """CLIP text core network with hyperparameters. + + Args: + vocabulary_size: int. The size of the token vocabulary. + embedding_dim: int. The output dimension of the embedding layer. + hidden_dim: int. The size of the transformer hidden state at the end + of each transformer layer. + num_layers: int. The number of transformer layers. + num_heads: int. The number of attention heads for each transformer. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + intermediate_activation: activation function. The activation that + is used for the first Dense layer in a two-layer feedforward network + for each transformer. + intermediate_output_index: optional int. The index of the intermediate + output. If specified, the output will become a dictionary with two + keys `"sequence_output"` and `"intermediate_output"`. + max_sequence_length: int. 
The maximum sequence length that this encoder + can consume. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the models computations and weights. Note that some + computations, such as softmax and layer normalization will always + be done a float32 precision regardless of dtype. + """ -class CLIPTextEncoder(keras.Model): def __init__( self, + vocabulary_size, embedding_dim, hidden_dim, num_layers, @@ -33,9 +58,9 @@ def __init__( intermediate_dim, intermediate_activation="quick_gelu", intermediate_output_index=None, - vocabulary_size=49408, - sequence_length=77, + max_sequence_length=77, dtype=None, + name=None, **kwargs, ): if ( @@ -44,13 +69,17 @@ def __init__( ): intermediate_output_index += num_layers + # `prefix` is used to prevent duplicate name when utilizing multiple + # CLIP models within a single model, such as in StableDiffusion3. + prefix = str(name) + "_" if name is not None else "" + # === Layers === self.embedding = TokenAndPositionEmbedding( vocabulary_size=vocabulary_size, - sequence_length=sequence_length, + sequence_length=max_sequence_length, embedding_dim=embedding_dim, dtype=dtype, - name="embedding", + name=f"{prefix}embedding", ) self.encoder_layers = [ CLIPEncoderBlock( @@ -59,59 +88,44 @@ def __init__( intermediate_dim, intermediate_activation, dtype=dtype, + name=f"{prefix}encoder_block_{i}", ) - for _ in range(num_layers) + for i in range(num_layers) ] self.layer_norm = layers.LayerNormalization( - epsilon=0.00001, dtype=dtype, name="layer_norm" - ) - self.text_projection = layers.Dense( - hidden_dim, - use_bias=False, - dtype=dtype, - name="text_projection", + epsilon=1e-6, dtype="float32", name=f"{prefix}layer_norm" ) # === Functional Model === - encoder_token_ids = layers.Input( - shape=(sequence_length,), dtype="int32", name="encoder_token_ids" + token_id_input = layers.Input( + shape=(None,), dtype="int32", name="token_ids" ) - x = self.embedding(encoder_token_ids) - encoder_intermediate_output = None - # Encoder. + x = self.embedding(token_id_input) + intermediate_output = None for i, block in enumerate(self.encoder_layers): x = block(x) if i == intermediate_output_index: - encoder_intermediate_output = x + intermediate_output = x x = self.layer_norm(x) - encoder_output = x - if encoder_intermediate_output is not None: - encoder_intermediate_output = self.layer_norm( - encoder_intermediate_output - ) - # Projection. 
- indices = ops.expand_dims( - ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1 - ) - pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1) - pooled_output = ops.squeeze(pooled_output, axis=1) - projection_output = self.text_projection(pooled_output) + sequence_output = x - outputs = { - "encoder_sequence_output": encoder_output, - "encoder_pooled_output": pooled_output, - "encoder_projection_output": projection_output, - } if intermediate_output_index is not None: - outputs["encoder_intermediate_output"] = encoder_intermediate_output - + outputs = { + "sequence_output": sequence_output, + "intermediate_output": intermediate_output, + } + else: + outputs = sequence_output super().__init__( - inputs={"encoder_token_ids": encoder_token_ids}, + inputs={"token_ids": token_id_input}, outputs=outputs, + name=name, **kwargs, ) # === Config === + self.vocabulary_size = vocabulary_size + self.max_sequence_length = max_sequence_length self.embedding_dim = embedding_dim self.hidden_dim = hidden_dim self.num_layers = num_layers @@ -119,22 +133,12 @@ def __init__( self.intermediate_dim = intermediate_dim self.intermediate_activation = intermediate_activation self.intermediate_output_index = intermediate_output_index - self.vocabulary_size = vocabulary_size - self.sequence_length = sequence_length - - if dtype is not None: - try: - self.dtype_policy = keras.dtype_policies.get(dtype) - # Before Keras 3.2, there is no `keras.dtype_policies.get`. - except AttributeError: - if isinstance(dtype, keras.DTypePolicy): - dtype = dtype.name - self.dtype_policy = keras.DTypePolicy(dtype) def get_config(self): config = super().get_config() config.update( { + "vocabulary_size": self.vocabulary_size, "embedding_dim": self.embedding_dim, "hidden_dim": self.hidden_dim, "num_layers": self.num_layers, @@ -142,8 +146,7 @@ def get_config(self): "intermediate_dim": self.intermediate_dim, "intermediate_activation": self.intermediate_activation, "intermediate_output_index": self.intermediate_output_index, - "vocabulary_size": self.vocabulary_size, - "sequence_length": self.sequence_length, + "max_sequence_length": self.max_sequence_length, } ) return config diff --git a/keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py b/keras_hub/src/models/clip/clip_tokenizer.py similarity index 65% rename from keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py rename to keras_hub/src/models/clip/clip_tokenizer.py index a9e17e8bda..2a594bda90 100644 --- a/keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +++ b/keras_hub/src/models/clip/clip_tokenizer.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
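A rough instantiation sketch of the refactored `CLIPTextEncoder` above, with tiny hyperparameters chosen only to keep the example cheap (CLIP-L itself uses `hidden_dim=768`, `num_layers=12`, `num_heads=12`, `vocabulary_size=49408`):

```python
import numpy as np

from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder

encoder = CLIPTextEncoder(
    vocabulary_size=64,
    embedding_dim=16,
    hidden_dim=16,
    num_layers=2,
    num_heads=2,
    intermediate_dim=32,
    intermediate_output_index=-2,  # also expose the penultimate block output
    max_sequence_length=77,
)
token_ids = np.random.randint(0, 64, size=(2, 77))
outputs = encoder({"token_ids": token_ids})
# With `intermediate_output_index` set, `outputs` is a dict holding
# "sequence_output" and "intermediate_output"; otherwise the encoder
# returns a single (batch, sequence, hidden_dim) tensor.
```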
+ +from keras_hub.src.api_export import keras_hub_export from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe +from keras_hub.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -21,10 +24,51 @@ tf = None +@keras_hub_export( + [ + "keras_hub.tokenizers.CLIPTokenizer", + "keras_hub.models.CLIPTokenizer", + ] +) class CLIPTokenizer(BytePairTokenizer): - def __init__(self, vocabulary=None, merges=None, **kwargs): - self.start_token = "<|startoftext|>" - self.end_token = "<|endoftext|>" + """A CLIP tokenizer using Byte-Pair Encoding subword segmentation. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_hub.tokenizers.BytePairTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by CLIP + models and provides a `from_preset()` method to automatically download + a matching vocabulary for a CLIP preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + vocabulary: string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, + it should be the file path to merge rules. The merge rule file + should have one merge rule per line. Every merge rule contains + merge entities separated by a space. + pad_with_end_token: bool. Whether to pad the output with `end_token`. + """ + + # TODO: Add example and `backbone_cls` once we have a CLIP model. + + backbone_cls = None + + def __init__( + self, + vocabulary=None, + merges=None, + pad_with_end_token=False, + **kwargs, + ): + self._add_special_token("<|startoftext|>", "start_token") + self._add_special_token("<|endoftext|>", "end_token") + self.pad_token_id = 0 super().__init__( vocabulary=vocabulary, @@ -33,35 +77,21 @@ def __init__(self, vocabulary=None, merges=None, **kwargs): **kwargs, ) - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - - self.start_token_id = self.token_to_id(self.start_token) - self.end_token_id = self.token_to_id(self.end_token) - self.pad_token_id = 0 - else: - self.end_token_id = None - self.start_token_id = None - self.pad_token_id = None + # When `pad_with_end_token` is True, we need to access the vocabulary, + # so the check is required. + if pad_with_end_token: + self._check_vocabulary() + self.pad_token_id = self.end_token_id + self.pad_with_end_token = pad_with_end_token def _bpe_merge_and_update_cache(self, tokens): """Process unseen tokens and add to cache.""" words = self._transform_bytes(tokens) - # In StableDiffusionV3, we need to add `` to the last word. + # In CLIP, we need to add `` to the last word. 
words = tf.strings.reduce_join(words, axis=1, separator=" ") words = tf.strings.join([words, ""]) words = tf.strings.split(words, sep=" ") - tokenized_words = self._bpe_merge(words) # For each word, join all its token by a whitespace, @@ -71,17 +101,20 @@ def _bpe_merge_and_update_cache(self, tokens): ) self.cache.insert(tokens, tokenized_words) + @preprocessing_function def tokenize(self, inputs): self._check_vocabulary() - if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): - inputs = tf.convert_to_tensor(inputs) - if self.add_prefix_space: inputs = tf.strings.join([" ", inputs]) - scalar_input = inputs.shape.rank == 0 - if scalar_input: + unbatched = inputs.shape.rank == 0 + if unbatched: inputs = tf.expand_dims(inputs, 0) + if inputs.shape.rank > 1: + raise ValueError( + "`tokenize()` inputs should be a string, list of strings, or " + f"string tensor with rank < 2. Received: {inputs}" + ) raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) @@ -131,12 +164,13 @@ def process_unseen_tokens(): tokens = tokens.to_tensor(shape=output_shape) # Convert to a dense output if input in scalar - if scalar_input: + if unbatched: tokens = tf.squeeze(tokens, 0) tf.ensure_shape(tokens, shape=[self.sequence_length]) return tokens + @preprocessing_function def detokenize(self, inputs): self._check_vocabulary() inputs, unbatched, _ = convert_to_ragged_batch(inputs) @@ -160,6 +194,11 @@ def detokenize(self, inputs): def get_config(self): config = super().get_config() + config.update( + { + "pad_with_end_token": self.pad_with_end_token, + } + ) # In the constructor, we pass the list of special tokens to the # `unsplittable_tokens` arg of the superclass' constructor. Hence, we # delete it from the config here. diff --git a/keras_hub/src/models/stable_diffusion_v3/clip_tokenizer_test.py b/keras_hub/src/models/clip/clip_tokenizer_test.py similarity index 88% rename from keras_hub/src/models/stable_diffusion_v3/clip_tokenizer_test.py rename to keras_hub/src/models/clip/clip_tokenizer_test.py index 77a5db8780..8ecea96a58 100644 --- a/keras_hub/src/models/stable_diffusion_v3/clip_tokenizer_test.py +++ b/keras_hub/src/models/clip/clip_tokenizer_test.py @@ -13,9 +13,7 @@ # limitations under the License. import pytest -from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import ( - CLIPTokenizer, -) +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.tests.test_case import TestCase @@ -40,6 +38,12 @@ def test_tokenizer_basics(self): expected_detokenize_output=["airplane", "airport"], ) + def test_pad_with_end_token(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs["pad_with_end_token"] = True + tokenizer = CLIPTokenizer(**init_kwargs) + self.assertEqual(tokenizer.pad_token_id, tokenizer.end_token_id) + def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): CLIPTokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"]) diff --git a/keras_hub/src/models/stable_diffusion_3/__init__.py b/keras_hub/src/models/stable_diffusion_3/__init__.py new file mode 100644 index 0000000000..fd48fde00f --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py b/keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py new file mode 100644 index 0000000000..a00988c4cd --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py @@ -0,0 +1,93 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras import layers +from keras import ops + + +class FlowMatchEulerDiscreteScheduler(layers.Layer): + """Flow-matching sampling euler scheduler. + + This layer is used to compute the discrete sigmas for the diffusion chain. + Typically, the sigma refers to the amount of noise added during the + diffusion process. + + Args: + num_train_timesteps: int. The number of diffusion steps to train the + model. + shift: float. The shift value for the timestep schedule. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + Call arguments: + inputs: The current step of the diffusion process. + num_steps: The total number of steps in the diffusion process. + + References: + - [Common Diffusion Noise Schedules and Sample Steps are Flawed]( + https://arxiv.org/abs/2305.08891). + - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis]( + https://arxiv.org/abs/2403.03206). 
+ """ + + def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs): + super().__init__(**kwargs) + self.num_train_timesteps = int(num_train_timesteps) + self.shift = float(shift) + + timesteps = ops.linspace( + 1, num_train_timesteps, num_train_timesteps, dtype="float32" + ) + timesteps = ops.flip(timesteps, axis=0) + sigmas = self._timestep_to_sigma(timesteps) + + self.timesteps = ops.multiply(sigmas, num_train_timesteps) + self.sigma_min = sigmas[-1] + self.sigma_max = sigmas[0] + + def _sigma_to_timestep(self, sigma): + return sigma * self.num_train_timesteps + + def _timestep_to_sigma(self, timestep): + sigma = ops.divide(timestep, self.num_train_timesteps) + if self.shift != 1.0: + sigma = ops.divide( + ops.multiply(self.shift, sigma), + ops.add(1, ops.multiply(self.shift - 1.0, sigma)), + ) + return sigma + + def call(self, inputs, num_steps): + start = self._sigma_to_timestep(self.sigma_max) + end = self._sigma_to_timestep(self.sigma_min) + step_size = ops.divide( + ops.subtract(end, start), ops.subtract(num_steps, 1) + ) + timestep = ops.add(start, ops.multiply(inputs, step_size)) + sigma = ops.maximum(self._timestep_to_sigma(timestep), 0.0) + timestep = self._sigma_to_timestep(sigma) + return sigma, timestep + + def get_config(self): + config = super().get_config() + config.update( + { + "num_train_timesteps": self.num_train_timesteps, + "shift": self.shift, + } + ) + return config + + def compute_output_shape(self): + # Returns a tuple of (sigma, timestep). + return (None,), (None,) diff --git a/keras_hub/src/models/stable_diffusion_v3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py similarity index 51% rename from keras_hub/src/models/stable_diffusion_v3/mmdit.py rename to keras_hub/src/models/stable_diffusion_3/mmdit.py index 619888baf1..c3c879788f 100644 --- a/keras_hub/src/models/stable_diffusion_v3/mmdit.py +++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py @@ -19,7 +19,8 @@ from keras import ops from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding -from keras_hub.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.utils.keras_utils import gelu_approximate from keras_hub.src.utils.keras_utils import standardize_data_format @@ -79,8 +80,8 @@ def call(self, inputs, height=None, width=None): width = width or self.width shape = ops.shape(inputs) feature_length = shape[-1] - top = ops.floor_divide(self.height - height, 2) - left = ops.floor_divide(self.width - width, 2) + top = ops.cast(ops.floor_divide(self.height - height, 2), "int32") + left = ops.cast(ops.floor_divide(self.width - width, 2), "int32") position_embedding = ops.convert_to_tensor(self.position_embeddings) position_embedding = ops.reshape( position_embedding, (self.height, self.width, feature_length) @@ -166,6 +167,305 @@ def compute_output_shape(self, inputs_shape): return output_shape +class DismantledBlock(layers.Layer): + def __init__( + self, + num_heads, + hidden_dim, + mlp_ratio=4.0, + use_projection=True, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_ratio = mlp_ratio + self.use_projection = use_projection + + head_dim = hidden_dim // num_heads + self.head_dim = head_dim + mlp_hidden_dim = int(hidden_dim * mlp_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + num_modulations = 6 if use_projection else 2 + self.num_modulations = num_modulations + + self.adaptive_norm_modulation = models.Sequential( + [ + 
layers.Activation("silu", dtype=self.dtype_policy), + layers.Dense( + num_modulations * hidden_dim, dtype=self.dtype_policy + ), + ], + name="adaptive_norm_modulation", + ) + self.norm1 = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype="float32", + name="norm1", + ) + self.attention_qkv = layers.Dense( + hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv" + ) + if use_projection: + self.attention_proj = layers.Dense( + hidden_dim, dtype=self.dtype_policy, name="attention_proj" + ) + self.norm2 = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype="float32", + name="norm2", + ) + self.mlp = models.Sequential( + [ + layers.Dense( + mlp_hidden_dim, + activation=gelu_approximate, + dtype=self.dtype_policy, + ), + layers.Dense( + hidden_dim, + dtype=self.dtype_policy, + ), + ], + name="mlp", + ) + + def build(self, inputs_shape, timestep_embedding): + self.adaptive_norm_modulation.build(timestep_embedding) + self.attention_qkv.build(inputs_shape) + self.norm1.build(inputs_shape) + if self.use_projection: + self.attention_proj.build(inputs_shape) + self.norm2.build(inputs_shape) + self.mlp.build(inputs_shape) + + def _modulate(self, inputs, shift, scale): + shift = ops.expand_dims(shift, axis=1) + scale = ops.expand_dims(scale, axis=1) + return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) + + def _compute_pre_attention(self, inputs, timestep_embedding, training=None): + batch_size = ops.shape(inputs)[0] + if self.use_projection: + modulation = self.adaptive_norm_modulation( + timestep_embedding, training=training + ) + modulation = ops.reshape( + modulation, (batch_size, 6, self.hidden_dim) + ) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = ops.unstack(modulation, 6, axis=1) + qkv = self.attention_qkv( + self._modulate(self.norm1(inputs), shift_msa, scale_msa), + training=training, + ) + qkv = ops.reshape( + qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) + ) + q, k, v = ops.unstack(qkv, 3, axis=2) + return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + modulation = self.adaptive_norm_modulation( + timestep_embedding, training=training + ) + modulation = ops.reshape( + modulation, (batch_size, 2, self.hidden_dim) + ) + shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1) + qkv = self.attention_qkv( + self._modulate(self.norm1(inputs), shift_msa, scale_msa), + training=training, + ) + qkv = ops.reshape( + qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) + ) + q, k, v = ops.unstack(qkv, 3, axis=2) + return (q, k, v) + + def _compute_post_attention( + self, inputs, inputs_intermediates, training=None + ): + x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates + attn = self.attention_proj(inputs, training=training) + x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn)) + x = ops.add( + x, + ops.multiply( + ops.expand_dims(gate_mlp, axis=1), + self.mlp( + self._modulate(self.norm2(x), shift_mlp, scale_mlp), + training=training, + ), + ), + ) + return x + + def call( + self, + inputs, + timestep_embedding=None, + inputs_intermediates=None, + pre_attention=True, + training=None, + ): + if pre_attention: + return self._compute_pre_attention( + inputs, timestep_embedding, training=training + ) + else: + return self._compute_post_attention( + inputs, inputs_intermediates, training=training + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + 
"hidden_dim": self.hidden_dim, + "mlp_ratio": self.mlp_ratio, + "use_projection": self.use_projection, + } + ) + return config + + +class MMDiTBlock(layers.Layer): + def __init__( + self, + num_heads, + hidden_dim, + mlp_ratio=4.0, + use_context_projection=True, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_ratio = mlp_ratio + self.use_context_projection = use_context_projection + + head_dim = hidden_dim // num_heads + self.head_dim = head_dim + self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim) + self._dot_product_equation = "aecd,abcd->acbe" + self._combine_equation = "acbe,aecd->abcd" + + self.x_block = DismantledBlock( + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_ratio=mlp_ratio, + use_projection=True, + dtype=self.dtype_policy, + name="x_block", + ) + self.context_block = DismantledBlock( + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_ratio=mlp_ratio, + use_projection=use_context_projection, + dtype=self.dtype_policy, + name="context_block", + ) + self.softmax = layers.Softmax(dtype="float32") + + def build(self, inputs_shape, context_shape, timestep_embedding_shape): + self.x_block.build(inputs_shape, timestep_embedding_shape) + self.context_block.build(context_shape, timestep_embedding_shape) + + def _compute_attention(self, query, key, value): + query = ops.multiply( + query, ops.cast(self._inverse_sqrt_key_dim, query.dtype) + ) + attention_scores = ops.einsum(self._dot_product_equation, key, query) + attention_scores = self.softmax(attention_scores) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + batch_size = ops.shape(attention_output)[0] + attention_output = ops.reshape( + attention_output, (batch_size, -1, self.num_heads * self.head_dim) + ) + return attention_output + + def call(self, inputs, context, timestep_embedding, training=None): + # Compute pre-attention. + x = inputs + if self.use_context_projection: + context_qkv, context_intermediates = self.context_block( + context, + timestep_embedding=timestep_embedding, + training=training, + ) + else: + context_qkv = self.context_block( + context, + timestep_embedding=timestep_embedding, + training=training, + ) + context_len = ops.shape(context_qkv[0])[1] + x_qkv, x_intermediates = self.x_block( + x, timestep_embedding=timestep_embedding, training=training + ) + q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1) + k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1) + v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1) + + # Compute attention. + attention = self._compute_attention(q, k, v) + context_attention = attention[:, :context_len] + x_attention = attention[:, context_len:] + + # Compute post-attention. 
+ x = self.x_block( + x_attention, + inputs_intermediates=x_intermediates, + pre_attention=False, + training=training, + ) + if self.use_context_projection: + context = self.context_block( + context_attention, + inputs_intermediates=context_intermediates, + pre_attention=False, + training=training, + ) + return x, context + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_ratio": self.mlp_ratio, + "use_context_projection": self.use_context_projection, + } + ) + return config + + def compute_output_shape( + self, inputs_shape, context_shape, timestep_embedding_shape + ): + if self.use_context_projection: + return inputs_shape, context_shape + else: + return inputs_shape + + class OutputLayer(layers.Layer): def __init__(self, hidden_dim, output_dim, **kwargs): super().__init__(**kwargs) @@ -186,11 +486,11 @@ def __init__(self, hidden_dim, output_dim, **kwargs): epsilon=1e-6, center=False, scale=False, - dtype=self.dtype_policy, + dtype="float32", name="norm", ) self.output_dense = layers.Dense( - output_dim, # patch_size ** 2 * input_channels + output_dim, use_bias=True, dtype=self.dtype_policy, name="output_dense", @@ -227,6 +527,11 @@ def get_config(self): ) return config + def compute_output_shape(self, inputs_shape): + outputs_shape = list(inputs_shape) + outputs_shape[-1] = self.output_dim + return outputs_shape + class Unpatch(layers.Layer): def __init__(self, patch_size, output_dim, **kwargs): @@ -263,18 +568,48 @@ def compute_output_shape(self, inputs_shape): return [inputs_shape[0], None, None, self.output_dim] -class MMDiT(keras.Model): +class MMDiT(Backbone): + """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3. + + MMDiT is introduced in [ + Scaling Rectified Flow Transformers for High-Resolution Image Synthesis]( + https://arxiv.org/abs/2403.03206). + + Args: + patch_size: int. The size of each square patch in the input image. + hidden_dim: int. The size of the transformer hidden state at the end + of each transformer layer. + num_layers: int. The number of transformer layers. + num_heads: int. The number of attention heads for each transformer. + position_size: int. The size of the height and width for the position + embedding. + mlp_ratio: float. The ratio of the mlp hidden dim to the transformer + latent_shape: tuple. The shape of the latent image. + context_shape: tuple. The shape of the context. + pooled_projection_shape: tuple. The shape of the pooled projection. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. 
+ """ + def __init__( self, patch_size, - num_heads, hidden_dim, - depth, + num_layers, + num_heads, position_size, - output_dim, mlp_ratio=4.0, latent_shape=(64, 64, 16), - context_shape=(1024, 4096), + context_shape=(None, 4096), pooled_projection_shape=(2048,), data_format=None, dtype=None, @@ -287,6 +622,7 @@ def __init__( ) image_height = latent_shape[0] // patch_size image_width = latent_shape[1] // patch_size + output_dim = latent_shape[-1] output_dim_in_final = patch_size**2 * output_dim data_format = standardize_data_format(data_format) if data_format != "channels_last": @@ -331,11 +667,11 @@ def __init__( num_heads, hidden_dim, mlp_ratio, - use_context_projection=not (i == depth - 1), + use_context_projection=not (i == num_layers - 1), dtype=dtype, name=f"joint_block_{i}", ) - for i in range(depth) + for i in range(num_layers) ] self.output_layer = OutputLayer( hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer" @@ -391,33 +727,22 @@ def __init__( self.patch_size = patch_size self.num_heads = num_heads self.hidden_dim = hidden_dim - self.depth = depth + self.num_layers = num_layers self.position_size = position_size - self.output_dim = output_dim self.mlp_ratio = mlp_ratio self.latent_shape = latent_shape self.context_shape = context_shape self.pooled_projection_shape = pooled_projection_shape - if dtype is not None: - try: - self.dtype_policy = keras.dtype_policies.get(dtype) - # Before Keras 3.2, there is no `keras.dtype_policies.get`. - except AttributeError: - if isinstance(dtype, keras.DTypePolicy): - dtype = dtype.name - self.dtype_policy = keras.DTypePolicy(dtype) - def get_config(self): config = super().get_config() config.update( { "patch_size": self.patch_size, - "num_heads": self.num_heads, "hidden_dim": self.hidden_dim, - "depth": self.depth, + "num_layers": self.num_layers, + "num_heads": self.num_heads, "position_size": self.position_size, - "output_dim": self.output_dim, "mlp_ratio": self.mlp_ratio, "latent_shape": self.latent_shape, "context_shape": self.context_shape, diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py new file mode 100644 index 0000000000..883c2b11fd --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py @@ -0,0 +1,630 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import keras +from keras import layers +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import ( + FlowMatchEulerDiscreteScheduler, +) +from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT +from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import ( + VAEImageDecoder, +) +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class CLIPProjection(layers.Layer): + def __init__(self, hidden_dim, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + + self.dense = layers.Dense( + hidden_dim, + use_bias=False, + dtype=self.dtype_policy, + name="dense", + ) + + def build(self, inputs_shape, token_ids_shape): + inputs_shape = list(inputs_shape) + self.dense.build([None, inputs_shape[-1]]) + + # Assign identity matrix to the kernel as default. + self.dense._kernel.assign(ops.eye(self.hidden_dim)) + + def call(self, inputs, token_ids): + indices = ops.expand_dims( + ops.cast(ops.argmax(token_ids, axis=-1), "int32"), axis=-1 + ) + pooled_output = ops.take_along_axis(inputs, indices[:, :, None], axis=1) + pooled_output = ops.squeeze(pooled_output, axis=1) + return self.dense(pooled_output) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + } + ) + return config + + def compute_output_shape(self, inputs_shape): + return (inputs_shape[0], self.hidden_dim) + + +class ClassifierFreeGuidanceConcatenate(layers.Layer): + def __init__(self, axis=0, **kwargs): + super().__init__(**kwargs) + self.axis = axis + + def call( + self, + latents, + positive_contexts, + negative_contexts, + positive_pooled_projections, + negative_pooled_projections, + timestep, + ): + timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1]) + latents = ops.concatenate([latents, latents], axis=self.axis) + contexts = ops.concatenate( + [positive_contexts, negative_contexts], axis=self.axis + ) + pooled_projections = ops.concatenate( + [positive_pooled_projections, negative_pooled_projections], + axis=self.axis, + ) + timesteps = ops.concatenate([timestep, timestep], axis=self.axis) + return latents, contexts, pooled_projections, timesteps + + def get_config(self): + return super().get_config() + + +class ClassifierFreeGuidance(layers.Layer): + """Perform classifier free guidance. + + This layer expects the inputs to be a concatenation of positive and negative + (or empty) noise. The computation applies the classifier-free guidance + scale. + + Args: + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + Call arguments: + inputs: A concatenation of positive and negative (or empty) noises. + guidance_scale: The scale factor for classifier-free guidance. + + Reference: + - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, inputs, guidance_scale): + positive_noise, negative_noise = ops.split(inputs, 2, axis=0) + return ops.add( + negative_noise, + ops.multiply( + guidance_scale, ops.subtract(positive_noise, negative_noise) + ), + ) + + def get_config(self): + return super().get_config() + + def compute_output_shape(self, inputs_shape): + outputs_shape = list(inputs_shape) + if outputs_shape[0] is not None: + outputs_shape[0] = outputs_shape[0] // 2 + return outputs_shape + + +class EulerStep(layers.Layer): + """A layer predicts the sample with the timestep and the predicted noise. + + Args: + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + Call arguments: + latents: A current sample created by the diffusion process. + noise_residual: The direct output from the diffusion model. + sigma: The amount of noise added at the current timestep. + sigma_next: The amount of noise added at the next timestep. + + References: + - [Common Diffusion Noise Schedules and Sample Steps are Flawed]( + https://arxiv.org/abs/2305.08891). + - [Elucidating the Design Space of Diffusion-Based Generative Models]( + https://arxiv.org/abs/2206.00364). + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, latents, noise_residual, sigma, sigma_next): + sigma_diff = ops.subtract(sigma_next, sigma) + return ops.add(latents, ops.multiply(sigma_diff, noise_residual)) + + def get_config(self): + return super().get_config() + + def compute_output_shape(self, latents_shape): + return latents_shape + + +class LatentSpaceDecoder(layers.Layer): + """Decoder to transform the latent space back to the original image space. + + During decoding, the latents are transformed back to the original image + space using the equation: `latents / scale + shift`. + + Args: + scale: float. The scaling factor. + shift: float. The shift factor. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + Call arguments: + latents: The latent tensor to be transformed. + + Reference: + - [High-Resolution Image Synthesis with Latent Diffusion Models]( + https://arxiv.org/abs/2112.10752). + """ + + def __init__(self, scale, shift, **kwargs): + super().__init__(**kwargs) + self.scale = scale + self.shift = shift + + def call(self, latents): + return ops.add(ops.divide(latents, self.scale), self.shift) + + def get_config(self): + config = super().get_config() + config.update( + { + "scale": self.scale, + "shift": self.shift, + } + ) + return config + + def compute_output_shape(self, latents_shape): + return latents_shape + + +@keras_hub_export("keras_hub.models.StableDiffusion3Backbone") +class StableDiffusion3Backbone(Backbone): + """Stable Diffusion 3 core network with hyperparameters. + + This backbone imports CLIP and T5 models as text encoders and implements the + base MMDiT and VAE networks for the Stable Diffusion 3 model. + + The default constructor gives a fully customizable, randomly initialized + MMDiT and VAE models with any hyperparameters. To load preset architectures + and weights, use the `from_preset` constructor. + + Args: + mmdit_patch_size: int. The size of each square patch in the input image + in MMDiT. + mmdit_hidden_dim: int. The size of the transformer hidden state at the + end of each transformer layer in MMDiT. + mmdit_num_layers: int. The number of transformer layers in MMDiT. + mmdit_num_heads: int. 
The number of attention heads for each + transformer in MMDiT. + mmdit_position_size: int. The size of the height and width for the + position embedding in MMDiT. + vae_stackwise_num_filters: list of ints. The number of filters for each + stack in VAE. + vae_stackwise_num_blocks: list of ints. The number of blocks for each + stack in VAE. + clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for + encoding the inputs. + clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for + encoding the inputs. + t5: optional `keras_hub.models.T5Encoder`. The text encoder for + encoding the inputs. + latent_channels: int. The number of channels in the latent. Defaults to + `16`. + output_channels: int. The number of channels in the output. Defaults to + `3`. + num_train_timesteps: int. The number of diffusion steps to train the + model. Defaults to `1000`. + shift: float. The shift value for the timestep schedule. Defaults to + `1.0`. + height: optional int. The output height of the image. + width: optional int. The output width of the image. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the models computations and weights. Note that some + computations, such as softmax and layer normalization will always + be done a float32 precision regardless of dtype. + + Example: + ```python + # Pretrained Stable Diffusion 3 model. + model = keras_hub.models.StableDiffusion3Backbone.from_preset( + "stable_diffusion_3_medium" + ) + + # Randomly initialized Stable Diffusion 3 model with custom config. + clip_l = keras_hub.models.CLIPTextEncoder(...) + clip_g = keras_hub.models.CLIPTextEncoder(...) + model = keras_hub.models.StableDiffusion3Backbone( + mmdit_patch_size=2, + mmdit_num_heads=4, + mmdit_hidden_dim=256, + mmdit_depth=4, + mmdit_position_size=192, + vae_stackwise_num_filters=[128, 128, 64, 32], + vae_stackwise_num_blocks=[1, 1, 1, 1], + clip_l=clip_l, + clip_g=clip_g, + ) + ``` + """ + + def __init__( + self, + mmdit_patch_size, + mmdit_hidden_dim, + mmdit_num_layers, + mmdit_num_heads, + mmdit_position_size, + vae_stackwise_num_filters, + vae_stackwise_num_blocks, + clip_l, + clip_g, + t5=None, + latent_channels=16, + output_channels=3, + num_train_timesteps=1000, + shift=1.0, + height=None, + width=None, + data_format=None, + dtype=None, + **kwargs, + ): + height = int(height or 1024) + width = int(width or 1024) + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + "`height` and `width` must be divisible by 8. 
" + f"Received: height={height}, width={width}" + ) + data_format = standardize_data_format(data_format) + if data_format != "channels_last": + raise NotImplementedError + latent_shape = (height // 8, width // 8, latent_channels) + context_shape = (None, 4096 if t5 is None else t5.hidden_dim) + pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,) + + # === Layers === + self.clip_l = clip_l + self.clip_l_projection = CLIPProjection( + clip_l.hidden_dim, dtype=dtype, name="clip_l_projection" + ) + self.clip_l_projection.build([None, clip_l.hidden_dim], None) + self.clip_g = clip_g + self.clip_g_projection = CLIPProjection( + clip_g.hidden_dim, dtype=dtype, name="clip_g_projection" + ) + self.clip_g_projection.build([None, clip_g.hidden_dim], None) + self.t5 = t5 + self.diffuser = MMDiT( + mmdit_patch_size, + mmdit_hidden_dim, + mmdit_num_layers, + mmdit_num_heads, + mmdit_position_size, + latent_shape=latent_shape, + context_shape=context_shape, + pooled_projection_shape=pooled_projection_shape, + data_format=data_format, + dtype=dtype, + name="diffuser", + ) + self.decoder = VAEImageDecoder( + vae_stackwise_num_filters, + vae_stackwise_num_blocks, + output_channels, + latent_shape=latent_shape, + data_format=data_format, + dtype=dtype, + name="decoder", + ) + # Set `dtype="float32"` to ensure the high precision for the noise + # residual. + self.scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=num_train_timesteps, + shift=shift, + dtype="float32", + name="scheduler", + ) + self.cfg_concat = ClassifierFreeGuidanceConcatenate( + dtype="float32", name="classifier_free_guidance_concat" + ) + self.cfg = ClassifierFreeGuidance( + dtype="float32", name="classifier_free_guidance" + ) + self.euler_step = EulerStep(dtype="float32", name="euler_step") + self.latent_space_decoder = LatentSpaceDecoder( + scale=self.decoder.scaling_factor, + shift=self.decoder.shift_factor, + dtype="float32", + name="latent_space_decoder", + ) + + # === Functional Model === + latent_input = keras.Input( + shape=latent_shape, + name="latents", + ) + clip_l_token_id_input = keras.Input( + shape=(None,), + dtype="int32", + name="clip_l_token_ids", + ) + clip_l_negative_token_id_input = keras.Input( + shape=(None,), + dtype="int32", + name="clip_l_negative_token_ids", + ) + clip_g_token_id_input = keras.Input( + shape=(None,), + dtype="int32", + name="clip_g_token_ids", + ) + clip_g_negative_token_id_input = keras.Input( + shape=(None,), + dtype="int32", + name="clip_g_negative_token_ids", + ) + token_ids = { + "clip_l": clip_l_token_id_input, + "clip_g": clip_g_token_id_input, + } + negative_token_ids = { + "clip_l": clip_l_negative_token_id_input, + "clip_g": clip_g_negative_token_id_input, + } + if self.t5 is not None: + t5_token_id_input = keras.Input( + shape=(None,), + dtype="int32", + name="t5_token_ids", + ) + t5_negative_token_id_input = keras.Input( + shape=(None,), + dtype="int32", + name="t5_negative_token_ids", + ) + token_ids["t5"] = t5_token_id_input + negative_token_ids["t5"] = t5_negative_token_id_input + num_step_input = keras.Input( + shape=(), + dtype="int32", + name="num_steps", + ) + guidance_scale_input = keras.Input( + shape=(), + dtype="float32", + name="guidance_scale", + ) + embeddings = self.encode_step(token_ids, negative_token_ids) + # Use `steps=0` to define the functional model. 
+ latents = self.denoise_step( + latent_input, + embeddings, + 0, + num_step_input[0], + guidance_scale_input[0], + ) + outputs = self.decode_step(latents) + inputs = { + "latents": latent_input, + "clip_l_token_ids": clip_l_token_id_input, + "clip_l_negative_token_ids": clip_l_negative_token_id_input, + "clip_g_token_ids": clip_g_token_id_input, + "clip_g_negative_token_ids": clip_g_negative_token_id_input, + "num_steps": num_step_input, + "guidance_scale": guidance_scale_input, + } + if self.t5 is not None: + inputs["t5_token_ids"] = t5_token_id_input + inputs["t5_negative_token_ids"] = t5_negative_token_id_input + super().__init__( + inputs=inputs, + outputs=outputs, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.mmdit_patch_size = mmdit_patch_size + self.mmdit_hidden_dim = mmdit_hidden_dim + self.mmdit_num_layers = mmdit_num_layers + self.mmdit_num_heads = mmdit_num_heads + self.mmdit_position_size = mmdit_position_size + self.vae_stackwise_num_filters = vae_stackwise_num_filters + self.vae_stackwise_num_blocks = vae_stackwise_num_blocks + self.latent_channels = latent_channels + self.output_channels = output_channels + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.height = height + self.width = width + + @property + def latent_shape(self): + return (None,) + tuple(self.diffuser.latent_shape) + + @property + def clip_hidden_dim(self): + return self.clip_l.hidden_dim + self.clip_g.hidden_dim + + @property + def t5_hidden_dim(self): + return 4096 if self.t5 is None else self.t5.hidden_dim + + def encode_step(self, token_ids, negative_token_ids): + clip_hidden_dim = self.clip_hidden_dim + t5_hidden_dim = self.t5_hidden_dim + + def encode(token_ids): + clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False) + clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False) + clip_l_projection = self.clip_l_projection( + clip_l_outputs["sequence_output"], + token_ids["clip_l"], + training=False, + ) + clip_g_projection = self.clip_g_projection( + clip_g_outputs["sequence_output"], + token_ids["clip_g"], + training=False, + ) + pooled_embeddings = ops.concatenate( + [clip_l_projection, clip_g_projection], + axis=-1, + ) + embeddings = ops.concatenate( + [ + clip_l_outputs["intermediate_output"], + clip_g_outputs["intermediate_output"], + ], + axis=-1, + ) + embeddings = ops.pad( + embeddings, + [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]], + ) + if self.t5 is not None: + t5_outputs = self.t5(token_ids["t5"], training=False) + embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2) + else: + padded_size = self.clip_l.max_sequence_length + embeddings = ops.pad( + embeddings, [[0, 0], [0, padded_size], [0, 0]] + ) + return embeddings, pooled_embeddings + + positive_embeddings, positive_pooled_embeddings = encode(token_ids) + negative_embeddings, negative_pooled_embeddings = encode( + negative_token_ids + ) + return ( + positive_embeddings, + negative_embeddings, + positive_pooled_embeddings, + negative_pooled_embeddings, + ) + + def denoise_step( + self, + latents, + embeddings, + steps, + num_steps, + guidance_scale, + ): + steps = ops.convert_to_tensor(steps) + steps_next = ops.add(steps, 1) + sigma, timestep = self.scheduler(steps, num_steps) + sigma_next, _ = self.scheduler(steps_next, num_steps) + + # Concatenation for classifier-free guidance. + concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat( + latents, *embeddings, timestep + ) + + # Diffusion. 
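+ # The diffuser runs once per step on the doubled (positive + negative)
+ # batch assembled by the concatenation above.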
+ predicted_noise = self.diffuser( + { + "latent": concated_latents, + "context": contexts, + "pooled_projection": pooled_projs, + "timestep": timesteps, + }, + training=False, + ) + + # Classifier-free guidance. + predicted_noise = self.cfg(predicted_noise, guidance_scale) + + # Euler step. + return self.euler_step(latents, predicted_noise, sigma, sigma_next) + + def decode_step(self, latents): + latents = self.latent_space_decoder(latents) + return self.decoder(latents, training=False) + + def get_config(self): + config = super().get_config() + config.update( + { + "mmdit_patch_size": self.mmdit_patch_size, + "mmdit_hidden_dim": self.mmdit_hidden_dim, + "mmdit_num_layers": self.mmdit_num_layers, + "mmdit_num_heads": self.mmdit_num_heads, + "mmdit_position_size": self.mmdit_position_size, + "vae_stackwise_num_filters": self.vae_stackwise_num_filters, + "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks, + "clip_l": layers.serialize(self.clip_l), + "clip_g": layers.serialize(self.clip_g), + "t5": layers.serialize(self.t5), + "latent_channels": self.latent_channels, + "output_channels": self.output_channels, + "num_train_timesteps": self.num_train_timesteps, + "shift": self.shift, + "height": self.height, + "width": self.width, + } + ) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated. + config = config.copy() + config["clip_l"] = layers.deserialize( + config["clip_l"], custom_objects=custom_objects + ) + config["clip_g"] = layers.deserialize( + config["clip_g"], custom_objects=custom_objects + ) + if config["t5"] is not None: + config["t5"] = layers.deserialize( + config["t5"], custom_objects=custom_objects + ) + return cls(**config) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py new file mode 100644 index 0000000000..64172e3444 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
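Within `denoise_step` above, the guidance combine and the solver update are delegated to the `ClassifierFreeGuidance` and `EulerStep` layers defined elsewhere in this change. A rough standalone sketch of the math they implement, using Keras ops; the helper names and the positive/negative stacking order are illustrative assumptions, not the actual layer code:

```python
from keras import ops


def classifier_free_guidance(stacked_noise, guidance_scale):
    # The positive/negative ordering of the stacked batch is an assumption
    # here; the real logic lives in the `ClassifierFreeGuidance` layer.
    positive, negative = ops.split(stacked_noise, 2, axis=0)
    return negative + guidance_scale * (positive - negative)


def euler_step(latents, predicted_noise, sigma, sigma_next):
    # First-order Euler update used by flow-matching samplers:
    # x_next = x + (sigma_next - sigma) * v(x, t).
    return latents + (sigma_next - sigma) * predicted_noise


latents = ops.ones((1, 8, 8, 16))
stacked = ops.concatenate([latents * 0.1, latents * 0.2], axis=0)
noise = classifier_free_guidance(stacked, guidance_scale=7.0)
latents = euler_step(latents, noise, sigma=1.0, sigma_next=0.96)
```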
+import pytest +from keras import ops + +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.tests.test_case import TestCase + + +class StableDiffusion3BackboneTest(TestCase): + def setUp(self): + clip_l = CLIPTextEncoder( + 20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l" + ) + clip_g = CLIPTextEncoder( + 20, 64, 64, 2, 2, 128, "gelu", -2, name="clip_g" + ) + self.init_kwargs = { + "mmdit_patch_size": 2, + "mmdit_hidden_dim": 16 * 2, + "mmdit_num_layers": 2, + "mmdit_num_heads": 2, + "mmdit_position_size": 192, + "vae_stackwise_num_filters": [32, 32, 32, 32], + "vae_stackwise_num_blocks": [1, 1, 1, 1], + "clip_l": clip_l, + "clip_g": clip_g, + "height": 64, + "width": 64, + } + self.input_data = { + "latents": ops.ones((2, 8, 8, 16)), + "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "num_steps": ops.ones((2,), dtype="int32"), + "guidance_scale": ops.ones((2,)), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=StableDiffusion3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 64, 64, 3), + # Since `clip_l` and `clip_g` were instantiated outside of + # `StableDiffusion3Backbone`, the mixed precision and + # quantization checks will fail. + run_mixed_precision_check=False, + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableDiffusion3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py new file mode 100644 index 0000000000..94a0de2214 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py @@ -0,0 +1,151 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) +from keras_hub.src.models.text_to_image import TextToImage + + +@keras_hub_export("keras_hub.models.StableDiffusion3TextToImage") +class StableDiffusion3TextToImage(TextToImage): + """An end-to-end Stable Diffusion 3 model for text-to-image generation. + + This model has a `generate()` method, which generates image based on a + prompt. + + Args: + backbone: A `keras_hub.models.StableDiffusion3Backbone` instance. 
+ preprocessor: A + `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance. + + Examples: + + Use `generate()` to do image generation. + ```python + text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset( + "stable_diffusion_3_medium", height=512, width=512 + ) + text_to_image.generate( + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + ) + + # Generate with batched prompts. + text_to_image.generate( + ["cute wallpaper art of a cat", "cute wallpaper art of a dog"] + ) + + # Generate with different `num_steps` and `classifier_free_guidance_scale`. + text_to_image.generate( + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + num_steps=50, + classifier_free_guidance_scale=5.0, + ) + ``` + """ + + backbone_cls = StableDiffusion3Backbone + preprocessor_cls = StableDiffusion3TextToImagePreprocessor + + def __init__( + self, + backbone, + preprocessor, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + outputs = backbone.output + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def fit(self, *args, **kwargs): + raise NotImplementedError( + "Currently, `fit` is not supported for " + "`StableDiffusion3TextToImage`." + ) + + def generate_step( + self, + latents, + token_ids, + negative_token_ids, + num_steps, + guidance_scale, + ): + """A compilable generation function for batched of inputs. + + This function represents the inner, XLA-compilable, generation function + for batched inputs. + + Args: + latents: A (batch_size, height, width, channels) tensor + containing the latents to start generation from. Typically, this + tensor is sampled from the Gaussian distribution. + token_ids: A (batch_size, num_tokens) tensor containing the + tokens based on the input prompts. + negative_token_ids: A (batch_size, num_tokens) tensor + containing the negative tokens based on the input prompts. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). Higher scale encourages to + generate images that are closely linked to prompts, usually at + the expense of lower image quality. + """ + # Encode inputs. + embeddings = self.backbone.encode_step(token_ids, negative_token_ids) + + # Denoise. + def body_fun(step, latents): + return self.backbone.denoise_step( + latents, + embeddings, + step, + num_steps, + guidance_scale, + ) + + latents = ops.fori_loop(0, num_steps, body_fun, latents) + + # Decode. + return self.backbone.decode_step(latents) + + def generate( + self, + inputs, + negative_inputs=None, + num_steps=28, + guidance_scale=7.0, + seed=None, + ): + return super().generate( + inputs, + negative_inputs=negative_inputs, + num_steps=num_steps, + guidance_scale=guidance_scale, + seed=seed, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py new file mode 100644 index 0000000000..2a0656bdf4 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py @@ -0,0 +1,77 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras import layers + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.preprocessor import Preprocessor + + +@keras_hub_export("keras_hub.models.StableDiffusion3TextToImagePreprocessor") +class StableDiffusion3TextToImagePreprocessor(Preprocessor): + """Stable Diffusion 3 text-to-image model preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.StableDiffusion3TextToImage`. + + For use with generation, the layer exposes one methods + `generate_preprocess()`. + + Args: + clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance. + clip_g_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance. + t5_preprocessor: A optional `keras_hub.models.T5Preprocessor` instance. + """ + + def __init__( + self, + clip_l_preprocessor, + clip_g_preprocessor, + t5_preprocessor=None, + **kwargs, + ): + super().__init__(**kwargs) + self.clip_l_preprocessor = clip_l_preprocessor + self.clip_g_preprocessor = clip_g_preprocessor + self.t5_preprocessor = t5_preprocessor + + def build(self, input_shape): + self.built = True + + def generate_preprocess(self, x): + token_ids = {} + token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"] + token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"] + if self.t5_preprocessor is not None: + token_ids["t5"] = self.t5_preprocessor(x)["token_ids"] + return token_ids + + def get_config(self): + config = super().get_config() + config.update( + { + "clip_l_preprocessor": layers.serialize( + self.clip_l_preprocessor + ), + "clip_g_preprocessor": layers.serialize( + self.clip_g_preprocessor + ), + "t5_preprocessor": layers.serialize(self.t5_preprocessor), + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self.clip_l_preprocessor.sequence_length diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py new file mode 100644 index 0000000000..58d2bb1a49 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py @@ -0,0 +1,73 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
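For reference, wiring this preprocessor up only needs the two CLIP sub-preprocessors (plus an optional T5 one). A minimal sketch reusing the toy vocabulary from the test fixtures below; all variable names are illustrative:

```python
from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
    StableDiffusion3TextToImagePreprocessor,
)

# Toy vocabulary and merges, mirroring the test fixtures in this change.
vocab = {"air": 0, "plane": 1, "port": 2, "<|endoftext|>": 3, "<|startoftext|>": 4}
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a", "po rt", "pla ne"]

clip_l_preprocessor = CLIPPreprocessor(
    CLIPTokenizer(vocabulary=vocab, merges=merges, pad_with_end_token=True),
    sequence_length=8,
)
clip_g_preprocessor = CLIPPreprocessor(
    CLIPTokenizer(vocabulary=vocab, merges=merges), sequence_length=8
)
preprocessor = StableDiffusion3TextToImagePreprocessor(
    clip_l_preprocessor, clip_g_preprocessor
)

token_ids = preprocessor.generate_preprocess(["airplane"])
# `token_ids` has "clip_l" and "clip_g" keys (plus "t5" when a T5
# preprocessor is provided), matching the backbone's token id inputs.
```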
+import pytest + +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) +from keras_hub.src.tests.test_case import TestCase + + +class StableDiffusion3TextToImagePreprocessorTest(TestCase): + def setUp(self): + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + clip_l_tokenizer = CLIPTokenizer( + vocabulary=vocab, merges=merges, pad_with_end_token=True + ) + clip_g_tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges) + clip_l_preprocessor = CLIPPreprocessor( + clip_l_tokenizer, sequence_length=8 + ) + clip_g_preprocessor = CLIPPreprocessor( + clip_g_tokenizer, sequence_length=8 + ) + self.init_kwargs = { + "clip_l_preprocessor": clip_l_preprocessor, + "clip_g_preprocessor": clip_g_preprocessor, + } + self.input_data = ["airplane"] + + def test_preprocessor_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_preprocessing_layer_test( + cls=StableDiffusion3TextToImagePreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[4, 9, 5, 7, 2, 0, 0, 0]], # Labels shifted. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Zero out unlabeled examples. + ), + ) + + def test_generate_preprocess(self): + preprocessor = StableDiffusion3TextToImagePreprocessor( + **self.init_kwargs + ) + x = preprocessor.generate_preprocess(self.input_data) + self.assertIn("clip_l", x) + self.assertIn("clip_g", x) + self.assertAllEqual(x["clip_l"][0], [4, 0, 1, 3, 3, 3, 3, 3]) + self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 3, 3, 3, 3]) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py new file mode 100644 index 0000000000..570f8c3e6b --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py @@ -0,0 +1,155 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
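The task test that follows exercises `generate()` end to end with toy components. For orientation, the intended user-facing call looks roughly like the sketch below, assuming the `stable_diffusion_3_medium` preset referenced in the docstrings is available:

```python
import keras_hub

# `height`/`width` are forwarded to the backbone via `from_preset`.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", height=512, width=512
)
images = text_to_image.generate(
    ["cute wallpaper art of a cat"],
    negative_inputs=["blurry, low quality"],
    num_steps=28,  # Default number of diffusion steps in `generate()`.
    guidance_scale=7.0,  # Default classifier-free guidance scale.
    seed=42,
)
```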
+import keras +import pytest +from keras import ops + +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( + StableDiffusion3TextToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) +from keras_hub.src.tests.test_case import TestCase + + +class StableDiffusion3TextToImageTest(TestCase): + def setUp(self): + # Instantiate the preprocessor. + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True) + clip_g_tokenizer = CLIPTokenizer(vocab, merges) + clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer) + clip_g_preprocessor = CLIPPreprocessor(clip_g_tokenizer) + self.preprocessor = StableDiffusion3TextToImagePreprocessor( + clip_l_preprocessor, clip_g_preprocessor + ) + + self.backbone = StableDiffusion3Backbone( + mmdit_patch_size=2, + mmdit_hidden_dim=16 * 2, + mmdit_num_layers=2, + mmdit_num_heads=2, + mmdit_position_size=192, + vae_stackwise_num_filters=[32, 32, 32, 32], + vae_stackwise_num_blocks=[1, 1, 1, 1], + clip_l=CLIPTextEncoder( + 20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l" + ), + clip_g=CLIPTextEncoder( + 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" + ), + height=128, + width=128, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.input_data = { + "latents": ops.ones((2, 16, 16, 16)), + "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "num_steps": ops.ones((2,), dtype="int32"), + "guidance_scale": ops.ones((2,)), + } + + def test_text_to_image_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=StableDiffusion3TextToImage, + init_kwargs=self.init_kwargs, + train_data=None, + expected_output_shape=(2, 128, 128, 3), + ) + + def test_generate(self): + text_to_image = StableDiffusion3TextToImage(**self.init_kwargs) + seed = 42 + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = text_to_image.generate(prompt, negative_prompt, seed=seed) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + text_to_image.preprocessor = None + output2 = text_to_image.generate( + prompt_ids, negative_prompt_ids, seed=seed + ) + self.assertAllClose(output, output2) + + def test_generate_with_lower_precision(self): + original_floatx = keras.config.floatx() + try: + for dtype in ["float16", "bfloat16"]: + keras.config.set_floatx(dtype) + text_to_image = StableDiffusion3TextToImage(**self.init_kwargs) + seed = 42 + # String input. 
+ prompt = ["airplane"] + negative_prompt = [""] + output = text_to_image.generate( + prompt, negative_prompt, seed=seed + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + text_to_image.preprocessor = None + output2 = text_to_image.generate( + prompt_ids, negative_prompt_ids, seed=seed + ) + self.assertAllClose(output, output2) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. + keras.config.set_floatx(original_floatx) + + def test_generate_compilation(self): + text_to_image = StableDiffusion3TextToImage(**self.init_kwargs) + # Assert we do not recompile with successive calls. + text_to_image.generate("airplane") + first_fn = text_to_image.generate_function + text_to_image.generate("airplane") + second_fn = text_to_image.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + text_to_image.compile() + self.assertIsNone(text_to_image.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableDiffusion3TextToImage, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py b/keras_hub/src/models/stable_diffusion_3/t5_encoder.py similarity index 93% rename from keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py rename to keras_hub/src/models/stable_diffusion_3/t5_encoder.py index 0a137551d0..9e9c261c90 100644 --- a/keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +++ b/keras_hub/src/models/stable_diffusion_3/t5_encoder.py @@ -20,7 +20,7 @@ from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer -class T5XXLTextEncoder(keras.Model): +class T5Encoder(keras.Model): def __init__( self, vocabulary_size, @@ -81,10 +81,10 @@ def __init__( # === Functional Model === encoder_token_id_input = keras.Input( - shape=(None,), dtype="int32", name="encoder_token_ids" + shape=(None,), dtype="int32", name="token_ids" ) encoder_padding_mask_input = keras.Input( - shape=(None,), dtype="int32", name="encoder_padding_mask" + shape=(None,), dtype="int32", name="padding_mask" ) # Encoder. x = self.token_embedding(encoder_token_id_input) @@ -102,14 +102,14 @@ def __init__( x, position_bias = output x = self.encoder_layer_norm(x) x = self.encoder_dropout(x) - encoder_output = x + sequence_output = x super().__init__( { - "encoder_token_ids": encoder_token_id_input, - "encoder_padding_mask": encoder_padding_mask_input, + "token_ids": encoder_token_id_input, + "padding_mask": encoder_padding_mask_input, }, - outputs=encoder_output, + outputs=sequence_output, **kwargs, ) diff --git a/keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py b/keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py new file mode 100644 index 0000000000..02b14e77e5 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py @@ -0,0 +1,333 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras import layers +from keras import ops + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class VAEAttention(layers.Layer): + def __init__(self, filters, groups=32, data_format=None, **kwargs): + super().__init__(**kwargs) + self.filters = filters + self.data_format = standardize_data_format(data_format) + gn_axis = -1 if self.data_format == "channels_last" else 1 + + self.group_norm = layers.GroupNormalization( + groups=groups, + axis=gn_axis, + epsilon=1e-6, + dtype="float32", + name="group_norm", + ) + self.query_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="query_conv2d", + ) + self.key_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="key_conv2d", + ) + self.value_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="value_conv2d", + ) + self.softmax = layers.Softmax(dtype="float32") + self.output_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="output_conv2d", + ) + + self.groups = groups + self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters)) + + def build(self, input_shape): + self.group_norm.build(input_shape) + self.query_conv2d.build(input_shape) + self.key_conv2d.build(input_shape) + self.value_conv2d.build(input_shape) + self.output_conv2d.build(input_shape) + + def call(self, inputs, training=None): + x = self.group_norm(inputs) + query = self.query_conv2d(x) + key = self.key_conv2d(x) + value = self.value_conv2d(x) + + if self.data_format == "channels_first": + query = ops.transpose(query, (0, 2, 3, 1)) + key = ops.transpose(key, (0, 2, 3, 1)) + value = ops.transpose(value, (0, 2, 3, 1)) + shape = ops.shape(inputs) + b = shape[0] + query = ops.reshape(query, (b, -1, self.filters)) + key = ops.reshape(key, (b, -1, self.filters)) + value = ops.reshape(value, (b, -1, self.filters)) + + # Compute attention. 
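+ # Single-head scaled dot-product attention over the flattened spatial
+ # positions; the scores pass through a float32 softmax and are cast back
+ # to the compute dtype.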
+ query = ops.multiply( + query, ops.cast(self._inverse_sqrt_filters, query.dtype) + ) + # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1] + attention_scores = ops.einsum("abc,adc->abd", query, key) + attention_scores = ops.cast( + self.softmax(attention_scores), self.compute_dtype + ) + # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C] + attention_output = ops.einsum("abc,adb->adc", value, attention_scores) + x = ops.reshape(attention_output, shape) + + x = self.output_conv2d(x) + if self.data_format == "channels_first": + x = ops.transpose(x, (0, 3, 1, 2)) + x = ops.add(x, inputs) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "groups": self.groups, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None): + data_format = standardize_data_format(data_format) + gn_axis = -1 if data_format == "channels_last" else 1 + input_filters = x.shape[gn_axis] + + residual = x + x = layers.GroupNormalization( + groups=32, + axis=gn_axis, + epsilon=1e-6, + dtype="float32", + name=f"{name}_norm1", + )(x) + x = layers.Activation("swish", dtype=dtype)(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"{name}_conv1", + )(x) + x = layers.GroupNormalization( + groups=32, + axis=gn_axis, + epsilon=1e-6, + dtype="float32", + name=f"{name}_norm2", + )(x) + x = layers.Activation("swish", dtype=dtype)(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"{name}_conv2", + )(x) + if input_filters != filters: + residual = layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=dtype, + name=f"{name}_residual_projection", + )(residual) + x = layers.Add(dtype=dtype)([residual, x]) + return x + + +class VAEImageDecoder(Backbone): + """Decoder for the VAE model used in Stable Diffusion 3. + + Args: + stackwise_num_filters: list of ints. The number of filters for each + stack. + stackwise_num_blocks: list of ints. The number of blocks for each stack. + output_channels: int. The number of channels in the output. + latent_shape: tuple. The shape of the latent image. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. 
+ """ + + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + output_channels=3, + latent_shape=(None, None, 16), + data_format=None, + dtype=None, + **kwargs, + ): + data_format = standardize_data_format(data_format) + gn_axis = -1 if data_format == "channels_last" else 1 + + # === Functional Model === + latent_inputs = layers.Input(shape=latent_shape) + + x = layers.Conv2D( + stackwise_num_filters[0], + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name="input_projection", + )(latent_inputs) + x = apply_resnet_block( + x, + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_block0", + ) + x = VAEAttention( + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_attention", + )(x) + x = apply_resnet_block( + x, + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_block1", + ) + + # Stacks. + for i, filters in enumerate(stackwise_num_filters): + for j in range(stackwise_num_blocks[i]): + x = apply_resnet_block( + x, + filters, + data_format=data_format, + dtype=dtype, + name=f"block{i}_{j}", + ) + if i != len(stackwise_num_filters) - 1: + # No upsamling in the last blcok. + x = layers.UpSampling2D( + 2, + data_format=data_format, + dtype=dtype, + name=f"upsample_{i}", + )(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"upsample_{i}_conv", + )(x) + + # Ouput block. + x = layers.GroupNormalization( + groups=32, + axis=gn_axis, + epsilon=1e-6, + dtype="float32", + name="output_norm", + )(x) + x = layers.Activation("swish", dtype=dtype, name="output_activation")(x) + image_outputs = layers.Conv2D( + output_channels, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name="output_projection", + )(x) + super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.output_channels = output_channels + self.latent_shape = latent_shape + + @property + def scaling_factor(self): + """The scaling factor for the latent space. + + This is used to scale the latent space to have unit variance when + training the diffusion model. + """ + return 1.5305 + + @property + def shift_factor(self): + """The shift factor for the latent space. + + This is used to shift the latent space to have zero mean when + training the diffusion model. + """ + return 0.0609 + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "output_channels": self.output_channels, + "image_shape": self.latent_shape, + } + ) + return config diff --git a/keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py b/keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py deleted file mode 100644 index f7cfa461d5..0000000000 --- a/keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import keras - -from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker -from keras_hub.src.models.preprocessor import Preprocessor -from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import ( - CLIPTokenizer, -) -from keras_hub.src.utils.tensor_utils import preprocessing_function - -try: - import tensorflow as tf -except ImportError: - tf = None - - -class CLIPPreprocessor(Preprocessor): - tokenizer_cls = CLIPTokenizer - - def __init__( - self, - tokenizer, - sequence_length=77, - add_start_token=True, - add_end_token=False, - to_lower=True, - pad_with_end_token=True, - **kwargs, - ): - super().__init__(**kwargs) - self.tokenizer = tokenizer - self.sequence_length = sequence_length - self.add_start_token = add_start_token - self.add_end_token = add_end_token - self.to_lower = to_lower - self.pad_with_end_token = pad_with_end_token - - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - pad_value = self.tokenizer.pad_token_id - if self.pad_with_end_token: - pad_value = self.tokenizer.end_token_id - - self.packer = StartEndPacker( - start_value=self.tokenizer.start_token_id, - end_value=self.tokenizer.end_token_id, - pad_value=pad_value, - sequence_length=self.sequence_length, - return_padding_mask=True, - ) - self.built = True - - @preprocessing_function - def call(self, x, y=None, sample_weight=None, sequence_length=None): - if self.to_lower: - x = tf.strings.lower(x) - token_ids, padding_mask = self.packer( - self.tokenizer(x), - sequence_length=sequence_length or self.sequence_length, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - x = { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "add_start_token": self.add_start_token, - "add_end_token": self.add_end_token, - "to_lower": self.to_lower, - "pad_with_end_token": self.pad_with_end_token, - } - ) - return config diff --git a/keras_hub/src/models/stable_diffusion_v3/mmdit_block.py b/keras_hub/src/models/stable_diffusion_v3/mmdit_block.py deleted file mode 100644 index a46d4c1b86..0000000000 --- a/keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
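The deleted `MMDiTBlock` below concatenates text (context) and image (latent) tokens before a single attention pass and splits the result afterwards; that logic presumably carries over into the new `MMDiT` model used by the backbone. A condensed, illustrative sketch of that joint attention (function and variable names are assumptions, not the actual implementation):

```python
import math

from keras import ops


def joint_attention(context_qkv, latent_qkv, num_heads, head_dim):
    # q/k/v tensors have shape (batch, seq, num_heads, head_dim); context and
    # latent tokens are concatenated along the sequence axis so that text and
    # image tokens attend to each other in one pass.
    q, k, v = [
        ops.concatenate([c, x], axis=1)
        for c, x in zip(context_qkv, latent_qkv)
    ]
    q = q / math.sqrt(head_dim)
    scores = ops.einsum("aecd,abcd->acbe", k, q)  # (batch, heads, q, k)
    scores = ops.softmax(scores, axis=-1)
    out = ops.einsum("acbe,aecd->abcd", scores, v)  # (batch, q, heads, dim)
    out = ops.reshape(out, (ops.shape(out)[0], -1, num_heads * head_dim))
    context_len = ops.shape(context_qkv[0])[1]
    return out[:, :context_len], out[:, context_len:]


# Toy shapes: 2 heads of size 8, 4 context tokens, 16 latent tokens.
ctx = [ops.ones((1, 4, 2, 8))] * 3
lat = [ops.ones((1, 16, 2, 8))] * 3
context_out, latent_out = joint_attention(ctx, lat, num_heads=2, head_dim=8)
```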
-import math - -from keras import layers -from keras import models -from keras import ops - -from keras_hub.src.utils.keras_utils import gelu_approximate - - -class DismantledBlock(layers.Layer): - def __init__( - self, - num_heads, - hidden_dim, - mlp_ratio=4.0, - use_projection=True, - **kwargs, - ): - super().__init__(**kwargs) - self.num_heads = num_heads - self.hidden_dim = hidden_dim - self.mlp_ratio = mlp_ratio - self.use_projection = use_projection - - head_dim = hidden_dim // num_heads - self.head_dim = head_dim - mlp_hidden_dim = int(hidden_dim * mlp_ratio) - self.mlp_hidden_dim = mlp_hidden_dim - num_modulations = 6 if use_projection else 2 - self.num_modulations = num_modulations - - self.adaptive_norm_modulation = models.Sequential( - [ - layers.Activation("silu", dtype=self.dtype_policy), - layers.Dense( - num_modulations * hidden_dim, dtype=self.dtype_policy - ), - ], - name="adaptive_norm_modulation", - ) - self.norm1 = layers.LayerNormalization( - epsilon=1e-6, - center=False, - scale=False, - dtype=self.dtype_policy, - name="norm1", - ) - self.attention_qkv = layers.Dense( - hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv" - ) - if use_projection: - self.attention_proj = layers.Dense( - hidden_dim, dtype=self.dtype_policy, name="attention_proj" - ) - self.norm2 = layers.LayerNormalization( - epsilon=1e-6, - center=False, - scale=False, - dtype=self.dtype_policy, - name="norm2", - ) - self.mlp = models.Sequential( - [ - layers.Dense( - mlp_hidden_dim, - activation=gelu_approximate, - dtype=self.dtype_policy, - ), - layers.Dense( - hidden_dim, - dtype=self.dtype_policy, - ), - ], - name="mlp", - ) - - def build(self, inputs_shape, timestep_embedding): - self.adaptive_norm_modulation.build(timestep_embedding) - self.attention_qkv.build(inputs_shape) - self.norm1.build(inputs_shape) - if self.use_projection: - self.attention_proj.build(inputs_shape) - self.norm2.build(inputs_shape) - self.mlp.build(inputs_shape) - - def _modulate(self, inputs, shift, scale): - shift = ops.expand_dims(shift, axis=1) - scale = ops.expand_dims(scale, axis=1) - return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) - - def _compute_pre_attention(self, inputs, timestep_embedding, training=None): - batch_size = ops.shape(inputs)[0] - if self.use_projection: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 6, self.hidden_dim) - ) - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = ops.unstack(modulation, 6, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, - ) - qkv = ops.reshape( - qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) - ) - q, k, v = ops.unstack(qkv, 3, axis=2) - return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp) - else: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 2, self.hidden_dim) - ) - shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, - ) - qkv = ops.reshape( - qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) - ) - q, k, v = ops.unstack(qkv, 3, axis=2) - return (q, k, v) - - def _compute_post_attention( - self, inputs, inputs_intermediates, training=None - ): - x, gate_msa, shift_mlp, scale_mlp, gate_mlp = 
inputs_intermediates - attn = self.attention_proj(inputs, training=training) - x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn)) - x = ops.add( - x, - ops.multiply( - ops.expand_dims(gate_mlp, axis=1), - self.mlp( - self._modulate(self.norm2(x), shift_mlp, scale_mlp), - training=training, - ), - ), - ) - return x - - def call( - self, - inputs, - timestep_embedding=None, - inputs_intermediates=None, - pre_attention=True, - training=None, - ): - if pre_attention: - return self._compute_pre_attention( - inputs, timestep_embedding, training=training - ) - else: - return self._compute_post_attention( - inputs, inputs_intermediates, training=training - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "num_heads": self.num_heads, - "hidden_dim": self.hidden_dim, - "mlp_ratio": self.mlp_ratio, - "use_projection": self.use_projection, - } - ) - return config - - -class MMDiTBlock(layers.Layer): - def __init__( - self, - num_heads, - hidden_dim, - mlp_ratio=4.0, - use_context_projection=True, - **kwargs, - ): - super().__init__(**kwargs) - self.num_heads = num_heads - self.hidden_dim = hidden_dim - self.mlp_ratio = mlp_ratio - self.use_context_projection = use_context_projection - - head_dim = hidden_dim // num_heads - self.head_dim = head_dim - self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim) - self._dot_product_equation = "aecd,abcd->acbe" - self._combine_equation = "acbe,aecd->abcd" - - self.x_block = DismantledBlock( - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_ratio=mlp_ratio, - use_projection=True, - dtype=self.dtype_policy, - name="x_block", - ) - self.context_block = DismantledBlock( - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_ratio=mlp_ratio, - use_projection=use_context_projection, - dtype=self.dtype_policy, - name="context_block", - ) - - def build(self, inputs_shape, context_shape, timestep_embedding_shape): - self.x_block.build(inputs_shape, timestep_embedding_shape) - self.context_block.build(context_shape, timestep_embedding_shape) - - def _compute_attention(self, query, key, value): - query = ops.multiply( - query, ops.cast(self._inverse_sqrt_key_dim, query.dtype) - ) - attention_scores = ops.einsum(self._dot_product_equation, key, query) - attention_scores = ops.nn.softmax(attention_scores, axis=-1) - attention_output = ops.einsum( - self._combine_equation, attention_scores, value - ) - batch_size = ops.shape(attention_output)[0] - attention_output = ops.reshape( - attention_output, (batch_size, -1, self.num_heads * self.head_dim) - ) - return attention_output - - def call(self, inputs, context, timestep_embedding, training=None): - # Compute pre-attention. - x = inputs - if self.use_context_projection: - context_qkv, context_intermediates = self.context_block( - context, - timestep_embedding=timestep_embedding, - training=training, - ) - else: - context_qkv = self.context_block( - context, - timestep_embedding=timestep_embedding, - training=training, - ) - context_len = ops.shape(context_qkv[0])[1] - x_qkv, x_intermediates = self.x_block( - x, timestep_embedding=timestep_embedding, training=training - ) - q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1) - k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1) - v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1) - - # Compute attention. - attention = self._compute_attention(q, k, v) - context_attention = attention[:, :context_len] - x_attention = attention[:, context_len:] - - # Compute post-attention. 
- x = self.x_block( - x_attention, - inputs_intermediates=x_intermediates, - pre_attention=False, - training=training, - ) - if self.use_context_projection: - context = self.context_block( - context_attention, - inputs_intermediates=context_intermediates, - pre_attention=False, - training=training, - ) - return x, context - else: - return x - - def get_config(self): - config = super().get_config() - config.update( - { - "num_heads": self.num_heads, - "hidden_dim": self.hidden_dim, - "mlp_ratio": self.mlp_ratio, - "use_context_projection": self.use_context_projection, - } - ) - return config - - def compute_output_shape( - self, inputs_shape, context_shape, timestep_embedding_shape - ): - if self.use_context_projection: - return inputs_shape, context_shape - else: - return inputs_shape diff --git a/keras_hub/src/models/stable_diffusion_v3/vae_attention.py b/keras_hub/src/models/stable_diffusion_v3/vae_attention.py deleted file mode 100644 index f9f2239f05..0000000000 --- a/keras_hub/src/models/stable_diffusion_v3/vae_attention.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -from keras import layers -from keras import ops - -from keras_hub.src.utils.keras_utils import standardize_data_format - - -class VAEAttention(layers.Layer): - def __init__(self, filters, groups=32, data_format=None, **kwargs): - super().__init__(**kwargs) - self.filters = filters - self.data_format = standardize_data_format(data_format) - gn_axis = -1 if self.data_format == "channels_last" else 1 - - self.group_norm = layers.GroupNormalization( - groups=groups, - axis=gn_axis, - epsilon=1e-6, - dtype=self.dtype_policy, - name="group_norm", - ) - self.query_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="query_conv2d", - ) - self.key_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="key_conv2d", - ) - self.value_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="value_conv2d", - ) - self.softmax = layers.Softmax(dtype="float32") - self.output_conv2d = layers.Conv2D( - filters, - 1, - 1, - data_format=self.data_format, - dtype=self.dtype_policy, - name="output_conv2d", - ) - - self.groups = groups - self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters)) - - def build(self, input_shape): - self.group_norm.build(input_shape) - self.query_conv2d.build(input_shape) - self.key_conv2d.build(input_shape) - self.value_conv2d.build(input_shape) - self.output_conv2d.build(input_shape) - - def call(self, inputs, training=None): - x = self.group_norm(inputs) - query = self.query_conv2d(x) - key = self.key_conv2d(x) - value = self.value_conv2d(x) - - if self.data_format == "channels_first": - query = ops.transpose(query, (0, 2, 3, 1)) - key = ops.transpose(key, (0, 2, 3, 1)) - value = ops.transpose(value, (0, 2, 3, 1)) - shape = ops.shape(inputs) - b = 
shape[0] - query = ops.reshape(query, (b, -1, self.filters)) - key = ops.reshape(key, (b, -1, self.filters)) - value = ops.reshape(value, (b, -1, self.filters)) - - # Compute attention. - query = ops.multiply( - query, ops.cast(self._inverse_sqrt_filters, query.dtype) - ) - # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1] - attention_scores = ops.einsum("abc,adc->abd", query, key) - attention_scores = ops.cast( - self.softmax(attention_scores), self.compute_dtype - ) - # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C] - attention_output = ops.einsum("abc,adb->adc", value, attention_scores) - x = ops.reshape(attention_output, shape) - - x = self.output_conv2d(x) - if self.data_format == "channels_first": - x = ops.transpose(x, (0, 3, 1, 2)) - x = ops.add(x, inputs) - return x - - def get_config(self): - config = super().get_config() - config.update( - { - "filters": self.filters, - "groups": self.groups, - } - ) - return config - - def compute_output_shape(self, input_shape): - return input_shape diff --git a/keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py b/keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py deleted file mode 100644 index 09afbef614..0000000000 --- a/keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import keras -from keras import layers - -from keras_hub.src.models.stable_diffusion_v3.vae_attention import VAEAttention -from keras_hub.src.utils.keras_utils import standardize_data_format - - -class VAEImageDecoder(keras.Model): - def __init__( - self, - stackwise_num_filters, - stackwise_num_blocks, - output_channels=3, - latent_shape=(None, None, 16), - data_format=None, - dtype=None, - **kwargs, - ): - data_format = standardize_data_format(data_format) - gn_axis = -1 if data_format == "channels_last" else 1 - - # === Functional Model === - latent_inputs = layers.Input(shape=latent_shape) - - x = layers.Conv2D( - stackwise_num_filters[0], - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name="input_projection", - )(latent_inputs) - x = apply_resnet_block( - x, - stackwise_num_filters[0], - data_format=data_format, - dtype=dtype, - name="input_block0", - ) - x = VAEAttention( - stackwise_num_filters[0], - data_format=data_format, - dtype=dtype, - name="input_attention", - )(x) - x = apply_resnet_block( - x, - stackwise_num_filters[0], - data_format=data_format, - dtype=dtype, - name="input_block1", - ) - - # Stacks. - for i, filters in enumerate(stackwise_num_filters): - for j in range(stackwise_num_blocks[i]): - x = apply_resnet_block( - x, - filters, - data_format=data_format, - dtype=dtype, - name=f"block{i}_{j}", - ) - if i != len(stackwise_num_filters) - 1: - # No upsamling in the last blcok. 
- x = layers.UpSampling2D( - 2, - data_format=data_format, - dtype=dtype, - name=f"upsample_{i}", - )(x) - x = layers.Conv2D( - filters, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name=f"upsample_{i}_conv", - )(x) - - # Ouput block. - x = layers.GroupNormalization( - groups=32, - axis=gn_axis, - epsilon=1e-6, - dtype=dtype, - name="output_norm", - )(x) - x = layers.Activation("swish", dtype=dtype, name="output_activation")(x) - image_outputs = layers.Conv2D( - output_channels, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name="output_projection", - )(x) - super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs) - - # === Config === - self.stackwise_num_filters = stackwise_num_filters - self.stackwise_num_blocks = stackwise_num_blocks - self.output_channels = output_channels - self.latent_shape = latent_shape - - if dtype is not None: - try: - self.dtype_policy = keras.dtype_policies.get(dtype) - # Before Keras 3.2, there is no `keras.dtype_policies.get`. - except AttributeError: - if isinstance(dtype, keras.DTypePolicy): - dtype = dtype.name - self.dtype_policy = keras.DTypePolicy(dtype) - - def get_config(self): - config = super().get_config() - config.update( - { - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_num_blocks": self.stackwise_num_blocks, - "output_channels": self.output_channels, - "image_shape": self.latent_shape, - } - ) - return config - - -def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None): - data_format = standardize_data_format(data_format) - gn_axis = -1 if data_format == "channels_last" else 1 - input_filters = x.shape[gn_axis] - - residual = x - x = layers.GroupNormalization( - groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1" - )(x) - x = layers.Activation("swish", dtype=dtype)(x) - x = layers.Conv2D( - filters, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name=f"{name}_conv1", - )(x) - x = layers.GroupNormalization( - groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2" - )(x) - x = layers.Activation("swish")(x) - x = layers.Conv2D( - filters, - 3, - 1, - padding="same", - data_format=data_format, - dtype=dtype, - name=f"{name}_conv2", - )(x) - if input_filters != filters: - residual = layers.Conv2D( - filters, - 1, - 1, - data_format=data_format, - dtype=dtype, - name=f"{name}_residual_projection", - )(residual) - x = layers.Add(dtype=dtype)([residual, x]) - return x diff --git a/keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py b/keras_hub/src/models/t5/t5_preprocessor.py similarity index 86% rename from keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py rename to keras_hub/src/models/t5/t5_preprocessor.py index 31fe9f38f6..4d7babe762 100644 --- a/keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +++ b/keras_hub/src/models/t5/t5_preprocessor.py @@ -13,13 +13,15 @@ # limitations under the License. 
import keras +from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.utils.tensor_utils import preprocessing_function -class T5XXLPreprocessor(Preprocessor): +@keras_hub_export("keras_hub.models.T5Preprocessor") +class T5Preprocessor(Preprocessor): tokenizer_cls = T5Tokenizer def __init__( @@ -49,10 +51,17 @@ def build(self, input_shape): self.built = True @preprocessing_function - def call(self, x, y=None, sample_weight=None, sequence_length=None): + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + sequence_length = sequence_length or self.sequence_length token_ids, padding_mask = self.packer( self.tokenizer(x), - sequence_length=sequence_length or self.sequence_length, + sequence_length=sequence_length, add_start_value=self.add_start_token, add_end_value=self.add_end_token, ) diff --git a/keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py b/keras_hub/src/models/t5/t5_preprocessor_test.py similarity index 86% rename from keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py rename to keras_hub/src/models/t5/t5_preprocessor_test.py index acf97ab357..b2b084675a 100644 --- a/keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py +++ b/keras_hub/src/models/t5/t5_preprocessor_test.py @@ -15,14 +15,12 @@ import pytest -from keras_hub.src.models.stable_diffusion_v3.t5_xxl_preprocessor import ( - T5XXLPreprocessor, -) +from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.tests.test_case import TestCase -class T5XXLPreprocessorTest(TestCase): +class T5PreprocessorTest(TestCase): def setUp(self): self.tokenizer = T5Tokenizer( proto=os.path.join(self.get_test_data_dir(), "t5_test_vocab.spm") @@ -35,7 +33,7 @@ def setUp(self): def test_preprocessor_basics(self): self.run_preprocessing_layer_test( - cls=T5XXLPreprocessor, + cls=T5Preprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output={ @@ -46,7 +44,7 @@ def test_preprocessor_basics(self): def test_no_start_end_token(self): input_data = ["the quick brown fox"] * 4 - preprocessor = T5XXLPreprocessor( + preprocessor = T5Preprocessor( tokenizer=self.tokenizer, sequence_length=8, add_start_token=False, @@ -58,7 +56,7 @@ def test_no_start_end_token(self): def test_sequence_length_override(self): input_data = "the quick brown fox" - preprocessor = T5XXLPreprocessor(**self.init_kwargs) + preprocessor = T5Preprocessor(**self.init_kwargs) x = preprocessor(input_data, sequence_length=4) self.assertAllEqual(x["token_ids"], [4, 9, 5, 1]) @@ -66,9 +64,9 @@ def test_sequence_length_override(self): @pytest.mark.extra_large def test_all_presets(self): self.skipTest("TODO") - for preset in T5XXLPreprocessor.presets: + for preset in T5Preprocessor.presets: self.run_preset_test( - cls=T5XXLPreprocessor, + cls=T5Preprocessor, preset=preset, input_data=self.input_data, ) diff --git a/keras_hub/src/models/text_to_image.py b/keras_hub/src/models/text_to_image.py new file mode 100644 index 0000000000..8defec0a2f --- /dev/null +++ b/keras_hub/src/models/text_to_image.py @@ -0,0 +1,295 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
diff --git a/keras_hub/src/models/text_to_image.py b/keras_hub/src/models/text_to_image.py
new file mode 100644
index 0000000000..8defec0a2f
--- /dev/null
+++ b/keras_hub/src/models/text_to_image.py
@@ -0,0 +1,295 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+from functools import partial
+
+import keras
+from keras import ops
+from keras import random
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+@keras_hub_export("keras_hub.models.TextToImage")
+class TextToImage(Task):
+    """Base class for text-to-image tasks.
+
+    `TextToImage` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    generation and generative fine-tuning.
+
+    `TextToImage` tasks provide an additional, high-level `generate()` function
+    which can be used to generate images with a string-in, image-out signature.
+
+    All `TextToImage` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+
+    ```python
+    # Load a Stable Diffusion 3 backbone with pre-trained weights.
+    text_to_image = keras_hub.models.TextToImage.from_preset(
+        "stable_diffusion_3_medium",
+    )
+    text_to_image.generate(
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+    )
+
+    # Load a Stable Diffusion 3 backbone at bfloat16 precision.
+    text_to_image = keras_hub.models.TextToImage.from_preset(
+        "stable_diffusion_3_medium",
+        dtype="bfloat16",
+    )
+    text_to_image.generate(
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+    )
+    ```
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    @property
+    def latent_shape(self):
+        return tuple(self.backbone.latent_shape)
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `TextToImage` task for training.
+
+        The `TextToImage` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value to these
+        arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default
+                optimizer for the given model and task. See
+                `keras.Model.compile` and `keras.optimizers` for more info on
+                possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a `keras.losses.MeanSquaredError`
+                loss will be applied. See `keras.Model.compile` and
+                `keras.losses` for more info on possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.MeanSquaredError` will be applied to
+                track the loss of the model during training. See
+                `keras.Model.compile` and `keras.metrics` for more info on
+                possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+ """ + # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414 + if optimizer == "auto": + optimizer = keras.optimizers.AdamW( + 1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0 + ) + if loss == "auto": + loss = keras.losses.MeanSquaredError() + if metrics == "auto": + metrics = [keras.metrics.MeanSquaredError()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + self.generate_function = None + + def generate_step(self, *args, **kwargs): + """Run generation on batches of input.""" + raise NotImplementedError + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + self.generate_function = self.generate_step + if keras.config.backend() == "torch": + import torch + + def wrapped_function(*args, **kwargs): + with torch.no_grad(): + return self.generate_step(*args, **kwargs) + + self.generate_function = wrapped_function + elif keras.config.backend() == "tensorflow" and not self.run_eagerly: + self.generate_function = tf.function( + self.generate_step, jit_compile=self.jit_compile + ) + elif keras.config.backend() == "jax" and not self.run_eagerly: + import jax + + @partial(jax.jit) + def compiled_function(state, *args, **kwargs): + ( + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.trainable_variables, trainable_variables), + zip(self.non_trainable_variables, non_trainable_variables), + ) + + with keras.StatelessScope(state_mapping=mapping): + outputs = self.generate_step(*args, **kwargs) + return outputs + + def wrapped_function(*args, **kwargs): + # Create an explicit tuple of all variable state. + state = ( + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], + ) + outputs = compiled_function(state, *args, **kwargs) + return outputs + + self.generate_function = wrapped_function + return self.generate_function + + def _normalize_generate_inputs(self, inputs): + """Normalize user input to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + """ + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + if isinstance(inputs, dict): + for key in inputs: + inputs[key], input_is_scalar = normalize(inputs[key]) + else: + inputs, input_is_scalar = normalize(inputs) + + return inputs, input_is_scalar + + def _normalize_generate_outputs(self, outputs, input_is_scalar): + """Normalize user output from the generate function. + + This function converts all output to numpy with a value range of + `[0, 255]`. If a batch dimension was added to the input, it is removed + from the output. 
+ """ + + def normalize(x): + outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0) + outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") + outputs = ops.convert_to_numpy(outputs) + if input_is_scalar: + outputs = outputs[0] + return outputs + + if isinstance(outputs[0], dict): + normalized = {} + for key in outputs[0]: + normalized[key] = normalize([x[key] for x in outputs]) + return normalized + return normalize([x for x in outputs]) + + def generate( + self, + inputs, + negative_inputs, + num_steps, + guidance_scale, + seed=None, + ): + """Generate image based on the provided `inputs` and `negative_inputs`. + + If `inputs` are a `tf.data.Dataset`, outputs will be generated + "batch-by-batch" and concatenated. Otherwise, all inputs will be + processed as batches. + + Args: + inputs: python data, tensor data, or a `tf.data.Dataset`. + negative_inputs: python data, tensor data, or a `tf.data.Dataset`. + Unlike `inputs`, these are used as negative inputs to guide the + generation. If not provided, it defaults to `""` for each input + in `inputs`. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. + seed: optional int. Used as a random seed. + """ + num_steps = ops.convert_to_tensor(num_steps, "int32") + guidance_scale = ops.convert_to_tensor(guidance_scale) + + # Setup our three main passes. + # 1. Preprocessing strings to dense integer tensors. + # 2. Generate outputs via a compiled function on dense tensors. + # 3. Postprocess dense tensors to a value range of `[0, 255]`. + generate_function = self.make_generate_function() + + def preprocess(x): + return self.preprocessor.generate_preprocess(x) + + # Normalize and preprocess inputs. + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + if negative_inputs is None: + negative_inputs = [""] * len(inputs) + negative_inputs, _ = self._normalize_generate_inputs(negative_inputs) + + if self.preprocessor is not None: + inputs = preprocess(inputs) + negative_inputs = preprocess(negative_inputs) + if isinstance(inputs, dict): + batch_size = len(inputs[list(inputs.keys())[0]]) + else: + batch_size = len(inputs) + + # Initialize random latents. + latent_shape = (batch_size,) + self.latent_shape[1:] + latents = random.normal(latent_shape, dtype="float32", seed=seed) + + # Text-to-image. + outputs = generate_function( + latents, + inputs, + negative_inputs, + num_steps, + guidance_scale, + ) + return self._normalize_generate_outputs(outputs, input_is_scalar)