From b058f4b3177c6e3ccc8e656309fcccb9edd07c5d Mon Sep 17 00:00:00 2001 From: Surya2k1 Date: Mon, 9 Dec 2024 21:16:23 +0530 Subject: [PATCH] Fix issue with unsorted dict --- keras/src/models/functional.py | 8 +++++++- keras/src/models/functional_test.py | 16 +++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 3bc6901171ee..e01052bc57ec 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -305,7 +305,13 @@ def _standardize_inputs(self, inputs): raise_exception = True else: raise_exception = True - + if ( + isinstance(self._inputs_struct, dict) + and not isinstance(inputs, dict) + and list(self._inputs_struct.keys()) + != sorted(self._inputs_struct.keys()) + ): + raise_exception = True self._maybe_warn_inputs_struct_mismatch( inputs, raise_exception=raise_exception ) diff --git a/keras/src/models/functional_test.py b/keras/src/models/functional_test.py index 1e4585c5653d..cbe3e2c035d4 100644 --- a/keras/src/models/functional_test.py +++ b/keras/src/models/functional_test.py @@ -15,6 +15,7 @@ from keras.src.models import Functional from keras.src.models import Model from keras.src.models import Sequential +from keras.src import ops class FunctionalTest(testing.TestCase): @@ -573,7 +574,6 @@ def is_input_warning(w): with pytest.warns() as warning_logs: model.predict([np.ones((2, 2)), np.zeros((2, 2))], verbose=0) self.assertLen(list(filter(is_input_warning, warning_logs)), 1) - # No warning for mismatched tuples and lists. model = Model([i1, i2], outputs) with warnings.catch_warnings(record=True) as warning_logs: @@ -699,3 +699,17 @@ def test_dict_input_to_list_model(self): "tags": tags_data, } ) + + def test_list_input_with_dict_build(self): + x1 = Input((10,), name="IT") + x2 = Input((10,), name="IS") + y = layers.subtract([x1, x2]) + model = Model(inputs={"IT": x1, "IS": x2}, outputs=y) + x1 = ops.ones((1, 10)) + x2 = ops.zeros((1, 10)) + r1 = model({"IT": x1, "IS": x2}) + with self.assertRaisesRegex( + ValueError, + "The structure of `inputs` doesn't match the expected structure", + ): + r2 = model([x1, x2])