Skip to content

Commit 503174b

Browse files
committed
Merge branch 'master' of github.com:keras-team/keras
2 parents 0ecc82c + 2ab47a1 commit 503174b

File tree

10 files changed

+114
-39
lines changed

10 files changed

+114
-39
lines changed

.github/workflows/scorecard.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ jobs:
4848
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
4949
# format to the repository Actions tab.
5050
- name: "Upload artifact"
51-
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1
51+
uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
5252
with:
5353
name: SARIF file
5454
path: results.sarif
5555
retention-days: 5
5656

5757
# Upload the results to GitHub's code scanning dashboard.
5858
- name: "Upload to code-scanning"
59-
uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9
59+
uses: github/codeql-action/upload-sarif@d39d31e687223d841ef683f52467bd88e9b21c14 # v3.25.3
6060
with:
6161
sarif_file: results.sarif

keras/src/backend/tensorflow/numpy.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1950,6 +1950,10 @@ def take(x, indices, axis=None):
19501950

19511951

19521952
def take_along_axis(x, indices, axis=None):
1953+
from keras.src.ops.operation_utils import (
1954+
compute_take_along_axis_output_shape,
1955+
)
1956+
19531957
x = convert_to_tensor(x)
19541958
indices = convert_to_tensor(indices, "int64")
19551959
if axis is None:
@@ -1959,7 +1963,13 @@ def take_along_axis(x, indices, axis=None):
19591963
f"Received: indices.shape={indices.shape}"
19601964
)
19611965
return take_along_axis(tf.reshape(x, [-1]), indices, 0)
1962-
rank = tf.rank(x)
1966+
1967+
# Compute the static output shape up front because, later on, all shape
1968+
# manipulations use dynamic shapes.
1969+
static_output_shape = compute_take_along_axis_output_shape(
1970+
x.shape, indices.shape, axis
1971+
)
1972+
rank = x.ndim
19631973
static_axis = axis
19641974
axis = axis + rank if axis < 0 else axis
19651975

@@ -1981,9 +1991,6 @@ def take_along_axis(x, indices, axis=None):
19811991
x = tf.broadcast_to(x, x_shape)
19821992
indices = tf.broadcast_to(indices, indices_shape)
19831993

1984-
# Save indices shape so we can restore it later.
1985-
possible_result_shape = indices.shape
1986-
19871994
# Correct the indices using "fill" mode which is the same as in jax
19881995
indices = tf.where(indices < 0, indices + x_shape[static_axis], indices)
19891996

@@ -1998,7 +2005,7 @@ def take_along_axis(x, indices, axis=None):
19982005
result = tf.gather(x, indices, batch_dims=1)
19992006
result = tf.reshape(result, indices_shape)
20002007
result = swapaxes(result, static_axis, -1)
2001-
result.set_shape(possible_result_shape)
2008+
result.set_shape(static_output_shape)
20022009
return result
20032010

20042011

keras/src/export/export_lib.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -635,16 +635,31 @@ def _get_save_spec(model):
635635
if not shapes_dict:
636636
return None
637637

638+
def make_tensor_spec(structure):
    """Recursively convert a nested structure of shapes to tensor specs.

    Args:
        structure: A dict, list or tuple. A flat list/tuple whose entries
            are all `int` or `None` is treated as a single shape; any
            other list/tuple, and every dict, is treated as a container
            and converted element by element.

    Returns:
        The same nesting built from plain Python `dict`/`list`/`tuple`
        objects, where each shape is replaced by a `tf.TensorSpec` whose
        first dimension is `None` and whose dtype is `model.input_dtype`
        (`model` comes from the enclosing scope).

    Raises:
        ValueError: If `structure` is not a dict, list or tuple.
    """
    # We need to turn wrapper structures like TrackingDict or _DictWrapper
    # into plain Python structures because they don't work with jax2tf/JAX.
    if isinstance(structure, dict):
        return {k: make_tensor_spec(v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        if all(isinstance(d, (int, type(None))) for d in structure):
            # Coerce the remaining dims to a tuple: slicing a list yields a
            # list, and `tuple + list` raises a TypeError.
            return tf.TensorSpec(
                shape=(None,) + tuple(structure[1:]),
                dtype=model.input_dtype,
            )
        result = [make_tensor_spec(v) for v in structure]
        return tuple(result) if isinstance(structure, tuple) else result
    raise ValueError(
        f"Unsupported type {type(structure)} for {structure}"
    )
654+
638655
if len(shapes_dict) == 1:
639-
shape = list(shapes_dict.values())[0]
640-
shape = (None,) + shape[1:]
641-
return tf.TensorSpec(shape=shape, dtype=model.input_dtype)
656+
value = list(shapes_dict.values())[0]
657+
return make_tensor_spec(value)
642658

643659
specs = {}
644-
for key, shape in shapes_dict.items():
660+
for key, value in shapes_dict.items():
645661
key = key.rstrip("_shape")
646-
shape = (None,) + shape[1:]
647-
specs[key] = tf.TensorSpec(shape=shape, dtype=model.input_dtype)
662+
specs[key] = make_tensor_spec(value)
648663

649664
return specs
650665

keras/src/export/export_lib_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,56 @@ def call(self, inputs):
9595
# Test with a different batch size
9696
revived_model.serve(tf.random.normal((6, 10)))
9797

98+
@parameterized.named_parameters(
    named_product(struct_type=["tuple", "array", "dict"])
)
def test_model_with_input_structure(self, struct_type):
    """Export and reload models whose input is a tuple, list or dict."""

    class TupleModel(models.Model):
        def call(self, inputs):
            first, second = inputs
            return ops.add(first, second)

    class ArrayModel(models.Model):
        def call(self, inputs):
            return ops.add(inputs[0], inputs[1])

    class DictModel(models.Model):
        def call(self, inputs):
            return ops.add(inputs["x"], inputs["y"])

    def sample():
        # One random (3, 10) tensor per structure leaf.
        return tf.random.normal((3, 10))

    if struct_type == "tuple":
        model, ref_input = TupleModel(), (sample(), sample())
    elif struct_type == "array":
        model, ref_input = ArrayModel(), [sample(), sample()]
    elif struct_type == "dict":
        model, ref_input = DictModel(), {"x": sample(), "y": sample()}

    # Calling the model once before exporting gives the reference output.
    keras_input = tree.map_structure(ops.convert_to_tensor, ref_input)
    ref_output = model(keras_input)

    temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
    export_lib.export_model(model, temp_filepath)
    revived_model = tf.saved_model.load(temp_filepath)
    self.assertAllClose(ref_output, revived_model.serve(ref_input))

    # Test with a different batch size
    def double_batch(t):
        return tf.concat([t, t], axis=0)

    revived_model.serve(tree.map_structure(double_batch, ref_input))
147+
98148
@parameterized.named_parameters(
99149
named_product(model_type=["sequential", "functional", "subclass"])
100150
)

keras/src/layers/layer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,10 +1350,10 @@ def __str__(self):
13501350
def __setattr__(self, name, value):
13511351
# Track Variables, Layers, Metrics, SeedGenerators.
13521352
name, value = self._setattr_hook(name, value)
1353-
if hasattr(self, "_tracker"):
1353+
if name != "_tracker":
1354+
if not hasattr(self, "_tracker"):
1355+
self._initialize_tracker()
13541356
value = self._tracker.track(value)
1355-
elif name != "_tracker":
1356-
self._initialize_tracker()
13571357
return super().__setattr__(name, value)
13581358

13591359
def _check_super_called(self):

keras/src/ops/numpy.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4859,28 +4859,9 @@ def call(self, x, indices):
48594859
return backend.numpy.take_along_axis(x, indices, axis=self.axis)
48604860

48614861
def compute_output_spec(self, x, indices):
4862-
x_shape = list(x.shape)
4863-
indices_shape = list(indices.shape)
4864-
if self.axis is None:
4865-
x_shape = [None] if None in x_shape else [int(np.prod(x_shape))]
4866-
4867-
if len(x_shape) != len(indices_shape):
4868-
raise ValueError(
4869-
"`x` and `indices` must have the same number of dimensions, "
4870-
f"but receive shape {x_shape} and {indices_shape}."
4871-
)
4872-
4873-
del x_shape[self.axis]
4874-
del indices_shape[self.axis]
4875-
output_shape = broadcast_shapes(x_shape, indices_shape)
4876-
size_on_axis = indices.shape[self.axis]
4877-
if self.axis == -1:
4878-
output_shape = output_shape + [size_on_axis]
4879-
elif self.axis >= 0:
4880-
output_shape.insert(self.axis, size_on_axis)
4881-
else:
4882-
output_shape.insert(self.axis + 1, size_on_axis)
4883-
4862+
output_shape = operation_utils.compute_take_along_axis_output_shape(
4863+
x.shape, indices.shape, self.axis
4864+
)
48844865
return KerasTensor(output_shape, dtype=x.dtype)
48854866

48864867

keras/src/ops/operation_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,25 @@ def compute_transpose_output_shape(input_shape, axes):
349349
return tuple(input_shape[ax] for ax in axes)
350350

351351

352+
def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
    """Compute the static output shape of a `take_along_axis` operation.

    Args:
        input_shape: Shape of the input, as a sequence whose entries may
            be `None`.
        indices_shape: Shape of the indices, as a sequence whose entries
            may be `None`.
        axis: Integer axis along which values are taken, or `None` to
            operate on the flattened input.

    Returns:
        The result of `broadcast_shapes` applied to `indices_shape` and the
        input shape with its size along `axis` replaced by the indices'
        size along `axis`.

    Raises:
        ValueError: If the (possibly flattened) input shape and the indices
            shape do not have the same number of dimensions.
    """
    input_shape = list(input_shape)
    indices_shape = list(indices_shape)
    if axis is None:
        # The input is treated as flattened: a single dimension whose size
        # is unknown as soon as any original dimension is unknown.
        input_shape = (
            [None] if None in input_shape else [int(np.prod(input_shape))]
        )
        # The flattened shape has exactly one axis. Indexing the lists
        # below with `None` would raise a TypeError.
        axis = 0

    if len(input_shape) != len(indices_shape):
        raise ValueError(
            "`x` and `indices` must have the same number of dimensions, "
            f"but receive shape {input_shape} and {indices_shape}."
        )

    # Along `axis` the output takes the indices' size; every other
    # dimension is resolved by broadcasting the two shapes.
    input_shape[axis] = indices_shape[axis]
    return broadcast_shapes(input_shape, indices_shape)
369+
370+
352371
def reduce_shape(shape, axis=None, keepdims=False):
353372
shape = list(shape)
354373
if axis is None:

requirements-common.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ rich
1717
build
1818
optree
1919
pytest-cov
20+
packaging

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Tensorflow.
2-
tensorflow-cpu~=2.16.1 # Pin to TF 2.16
2+
tensorflow-cpu~=2.16.1;sys_platform != 'darwin' # Pin to TF 2.16
3+
tensorflow~=2.16.1;sys_platform == 'darwin'
34

45
# Torch.
56
# TODO: Pin to < 2.3.0 (GitHub issue #19602)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def get_version(rel_path):
4343
"h5py",
4444
"optree",
4545
"ml-dtypes",
46+
"packaging",
4647
],
4748
# Supported Python versions
4849
python_requires=">=3.9",

0 commit comments

Comments
 (0)