File tree Expand file tree Collapse file tree 2 files changed +21
-6
lines changed
keras_nlp/src/models/stable_diffusion_v3 Expand file tree Collapse file tree 2 files changed +21
-6
lines changed Original file line number Diff line number Diff line change 11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import keras
14
15
from keras import layers
15
16
from keras import ops
16
17
17
18
from keras_nlp .src .layers .modeling .token_and_position_embedding import (
18
19
TokenAndPositionEmbedding ,
19
20
)
20
- from keras_nlp .src .models .backbone import Backbone
21
21
from keras_nlp .src .models .stable_diffusion_v3 .clip_encoder_block import (
22
22
CLIPEncoderBlock ,
23
23
)
24
24
25
25
26
- class CLIPTextEncoder (Backbone ):
26
+ class CLIPTextEncoder (keras . Model ):
27
27
def __init__ (
28
28
self ,
29
29
embedding_dim ,
@@ -108,7 +108,6 @@ def __init__(
108
108
super ().__init__ (
109
109
inputs = {"encoder_token_ids" : encoder_token_ids },
110
110
outputs = outputs ,
111
- dtype = dtype ,
112
111
** kwargs ,
113
112
)
114
113
@@ -123,6 +122,15 @@ def __init__(
123
122
self .vocabulary_size = vocabulary_size
124
123
self .sequence_length = sequence_length
125
124
125
+ if dtype is not None :
126
+ try :
127
+ self .dtype_policy = keras .dtype_policies .get (dtype )
128
+ # Before Keras 3.2, there is no `keras.dtype_policies.get`.
129
+ except AttributeError :
130
+ if isinstance (dtype , keras .DTypePolicy ):
131
+ dtype = dtype .name
132
+ self .dtype_policy = keras .DTypePolicy (dtype )
133
+
126
134
def get_config (self ):
127
135
config = super ().get_config ()
128
136
config .update (
Original file line number Diff line number Diff line change 16
16
from keras_nlp .src .layers .modeling .reversible_embedding import (
17
17
ReversibleEmbedding ,
18
18
)
19
- from keras_nlp .src .models .backbone import Backbone
20
19
from keras_nlp .src .models .t5 .t5_layer_norm import T5LayerNorm
21
20
from keras_nlp .src .models .t5 .t5_transformer_layer import T5TransformerLayer
22
21
23
22
24
- class T5XXLTextEncoder (Backbone ):
23
+ class T5XXLTextEncoder (keras . Model ):
25
24
def __init__ (
26
25
self ,
27
26
vocabulary_size ,
@@ -111,7 +110,6 @@ def __init__(
111
110
"encoder_padding_mask" : encoder_padding_mask_input ,
112
111
},
113
112
outputs = encoder_output ,
114
- dtype = dtype ,
115
113
** kwargs ,
116
114
)
117
115
@@ -128,6 +126,15 @@ def __init__(
128
126
self .layer_norm_epsilon = layer_norm_epsilon
129
127
self .tie_embedding_weights = tie_embedding_weights
130
128
129
+ if dtype is not None :
130
+ try :
131
+ self .dtype_policy = keras .dtype_policies .get (dtype )
132
+ # Before Keras 3.2, there is no `keras.dtype_policies.get`.
133
+ except AttributeError :
134
+ if isinstance (dtype , keras .DTypePolicy ):
135
+ dtype = dtype .name
136
+ self .dtype_policy = keras .DTypePolicy (dtype )
137
+
131
138
def get_config (self ):
132
139
config = super ().get_config ()
133
140
config .update (
You can’t perform that action at this time.
0 commit comments