Skip to content

Commit b9bc61e

Browse files
james77777778 authored and mattdangerw committed
Replace Backbone with keras.Model in CLIPTextEncoder and T5XXLTextEncoder (#1802)
1 parent c4627d1 commit b9bc61e

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,19 +11,19 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import keras
1415
from keras import layers
1516
from keras import ops
1617

1718
from keras_nlp.src.layers.modeling.token_and_position_embedding import (
1819
TokenAndPositionEmbedding,
1920
)
20-
from keras_nlp.src.models.backbone import Backbone
2121
from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import (
2222
CLIPEncoderBlock,
2323
)
2424

2525

26-
class CLIPTextEncoder(Backbone):
26+
class CLIPTextEncoder(keras.Model):
2727
def __init__(
2828
self,
2929
embedding_dim,
@@ -108,7 +108,6 @@ def __init__(
108108
super().__init__(
109109
inputs={"encoder_token_ids": encoder_token_ids},
110110
outputs=outputs,
111-
dtype=dtype,
112111
**kwargs,
113112
)
114113

@@ -123,6 +122,15 @@ def __init__(
123122
self.vocabulary_size = vocabulary_size
124123
self.sequence_length = sequence_length
125124

125+
if dtype is not None:
126+
try:
127+
self.dtype_policy = keras.dtype_policies.get(dtype)
128+
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
129+
except AttributeError:
130+
if isinstance(dtype, keras.DTypePolicy):
131+
dtype = dtype.name
132+
self.dtype_policy = keras.DTypePolicy(dtype)
133+
126134
def get_config(self):
127135
config = super().get_config()
128136
config.update(

keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,11 @@
1616
from keras_nlp.src.layers.modeling.reversible_embedding import (
1717
ReversibleEmbedding,
1818
)
19-
from keras_nlp.src.models.backbone import Backbone
2019
from keras_nlp.src.models.t5.t5_layer_norm import T5LayerNorm
2120
from keras_nlp.src.models.t5.t5_transformer_layer import T5TransformerLayer
2221

2322

24-
class T5XXLTextEncoder(Backbone):
23+
class T5XXLTextEncoder(keras.Model):
2524
def __init__(
2625
self,
2726
vocabulary_size,
@@ -111,7 +110,6 @@ def __init__(
111110
"encoder_padding_mask": encoder_padding_mask_input,
112111
},
113112
outputs=encoder_output,
114-
dtype=dtype,
115113
**kwargs,
116114
)
117115

@@ -128,6 +126,15 @@ def __init__(
128126
self.layer_norm_epsilon = layer_norm_epsilon
129127
self.tie_embedding_weights = tie_embedding_weights
130128

129+
if dtype is not None:
130+
try:
131+
self.dtype_policy = keras.dtype_policies.get(dtype)
132+
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
133+
except AttributeError:
134+
if isinstance(dtype, keras.DTypePolicy):
135+
dtype = dtype.name
136+
self.dtype_policy = keras.DTypePolicy(dtype)
137+
131138
def get_config(self):
132139
config = super().get_config()
133140
config.update(

0 commit comments

Comments
 (0)