Skip to content

Commit 00d0f49

Browse files
authored
Fix issue with unsorted dict (#20613)
1 parent c789805 commit 00d0f49

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

keras/src/models/functional.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,13 @@ def _standardize_inputs(self, inputs):
305305
raise_exception = True
306306
else:
307307
raise_exception = True
308-
308+
if (
309+
isinstance(self._inputs_struct, dict)
310+
and not isinstance(inputs, dict)
311+
and list(self._inputs_struct.keys())
312+
!= sorted(self._inputs_struct.keys())
313+
):
314+
raise_exception = True
309315
self._maybe_warn_inputs_struct_mismatch(
310316
inputs, raise_exception=raise_exception
311317
)

keras/src/models/functional_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from keras.src.models import Functional
1616
from keras.src.models import Model
1717
from keras.src.models import Sequential
18+
from keras.src import ops
1819

1920

2021
class FunctionalTest(testing.TestCase):
@@ -573,7 +574,6 @@ def is_input_warning(w):
573574
with pytest.warns() as warning_logs:
574575
model.predict([np.ones((2, 2)), np.zeros((2, 2))], verbose=0)
575576
self.assertLen(list(filter(is_input_warning, warning_logs)), 1)
576-
577577
# No warning for mismatched tuples and lists.
578578
model = Model([i1, i2], outputs)
579579
with warnings.catch_warnings(record=True) as warning_logs:
@@ -699,3 +699,17 @@ def test_dict_input_to_list_model(self):
699699
"tags": tags_data,
700700
}
701701
)
702+
703+
def test_list_input_with_dict_build(self):
704+
x1 = Input((10,), name="IT")
705+
x2 = Input((10,), name="IS")
706+
y = layers.subtract([x1, x2])
707+
model = Model(inputs={"IT": x1, "IS": x2}, outputs=y)
708+
x1 = ops.ones((1, 10))
709+
x2 = ops.zeros((1, 10))
710+
r1 = model({"IT": x1, "IS": x2})
711+
with self.assertRaisesRegex(
712+
ValueError,
713+
"The structure of `inputs` doesn't match the expected structure",
714+
):
715+
r2 = model([x1, x2])

0 commit comments

Comments
 (0)