Skip to content

Commit 4986616

Browse files
james77777778wang-xianghao
authored andcommitted
Suppress warnings for mismatched tuples and lists in functional models. (keras-team#20456)
1 parent ac204b4 commit 4986616

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

keras/src/models/functional.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ def _assert_input_compatibility(self, *args):
214214

215215
def _maybe_warn_inputs_struct_mismatch(self, inputs):
216216
try:
217+
# We first normalize to tuples before performing the check to
218+
# suppress warnings when encountering mismatched tuples and lists.
217219
tree.assert_same_structure(
218-
inputs, self._inputs_struct, check_types=False
220+
tree.lists_to_tuples(inputs),
221+
tree.lists_to_tuples(self._inputs_struct),
222+
check_types=False,
219223
)
220224
except:
221225
model_inputs_struct = tree.map_structure(

keras/src/models/functional_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23

34
import numpy as np
45
import pytest
@@ -503,13 +504,19 @@ def test_warning_for_mismatched_inputs_structure(self):
503504
model = Model({"i1": i1, "i2": i2}, outputs)
504505

505506
with pytest.warns() as record:
506-
model([np.ones((2, 2)), np.zeros((2, 2))])
507+
model.predict([np.ones((2, 2)), np.zeros((2, 2))], verbose=0)
507508
self.assertLen(record, 1)
508509
self.assertStartsWith(
509510
str(record[0].message),
510511
r"The structure of `inputs` doesn't match the expected structure:",
511512
)
512513

514+
# No warning for mismatched tuples and lists.
515+
model = Model([i1, i2], outputs)
516+
with warnings.catch_warnings(record=True) as warning_logs:
517+
model.predict((np.ones((2, 2)), np.zeros((2, 2))), verbose=0)
518+
self.assertLen(warning_logs, 0)
519+
513520
def test_for_functional_in_sequential(self):
514521
# Test for a v3.4.1 regression.
515522
if backend.image_data_format() == "channels_first":

0 commit comments

Comments
 (0)