Skip to content

Commit 0a6fa79

Browse files
Fix Segment Anything Model saving bug (#2138)
* Fix Segment Anything Model saving bug * Use keras iitializers for Keras2/3/ compatibility
1 parent 26bcee8 commit 0a6fa79

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

keras_cv/models/segmentation/segment_anything/sam_layers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,15 +275,12 @@ def __init__(self, num_positional_features, scale, **kwargs):
275275
super().__init__(**kwargs)
276276
self.num_positional_features = num_positional_features
277277
self.scale = scale
278-
init_func = lambda *a, **kw: self.scale * ops.random.normal(
279-
shape=(2, self.num_positional_features), dtype=self.dtype
280-
)
281278
self.positional_encoding_gaussian_matrix = self.add_weight(
282279
name="positional_encoding_gaussian_matrix",
283280
shape=(2, self.num_positional_features),
284281
dtype=self.dtype,
285282
trainable=False,
286-
initializer=init_func,
283+
initializer=keras.initializers.get("normal"),
287284
)
288285

289286
def build(self, input_shape=None):

0 commit comments

Comments
 (0)