|
16 | 16 | from keras.src.models import Functional
|
17 | 17 | from keras.src.models import Model
|
18 | 18 | from keras.src.models import Sequential
|
| 19 | +from keras.src.models.model import model_from_json |
19 | 20 |
|
20 | 21 |
|
21 | 22 | class FunctionalTest(testing.TestCase):
|
@@ -272,6 +273,19 @@ def test_restored_multi_output_type(self, out_type):
|
272 | 273 | out_val = model_restored(Input(shape=(3,), batch_size=2))
|
273 | 274 | self.assertIsInstance(out_val, out_type)
|
274 | 275 |
|
| 276 | + def test_restored_nested_input(self): |
| 277 | + input_a = Input(shape=(3,), batch_size=2, name="input_a") |
| 278 | + x = layers.Dense(5)(input_a) |
| 279 | + outputs = layers.Dense(4)(x) |
| 280 | + model = Functional([[input_a]], outputs) |
| 281 | + |
| 282 | + # Serialize and deserialize the model |
| 283 | + json_config = model.to_json() |
| 284 | + restored_json_config = model_from_json(json_config).to_json() |
| 285 | + |
| 286 | + # Check that the serialized model is the same as the original |
| 287 | + self.assertEqual(json_config, restored_json_config) |
| 288 | + |
275 | 289 | @pytest.mark.requires_trainable_backend
|
276 | 290 | def test_layer_getters(self):
|
277 | 291 | # Test mixing ops and layers
|
|
0 commit comments