Skip to content

Commit 64a8d4f

Browse files
JyotinderSinghHmm-1224
authored andcommitted
Fixes inconsistent serialization logic for inputs (keras-team#20993)
* Removes unnesting logic for input tensors in functional model deserialization flow * Adds test case for verifying nested input restoration after deserialization removes unnecessary imports * fixes imports
1 parent 414e3f4 commit 64a8d4f

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

keras/src/models/functional.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,8 +618,6 @@ def map_tensors(tensors):
618618

619619
input_tensors = map_tensors(functional_config["input_layers"])
620620
output_tensors = map_tensors(functional_config["output_layers"])
621-
if isinstance(input_tensors, list) and len(input_tensors) == 1:
622-
input_tensors = input_tensors[0]
623621
if isinstance(output_tensors, list) and len(output_tensors) == 1:
624622
output_tensors = output_tensors[0]
625623

keras/src/models/functional_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from keras.src.models import Functional
1717
from keras.src.models import Model
1818
from keras.src.models import Sequential
19+
from keras.src.models.model import model_from_json
1920

2021

2122
class FunctionalTest(testing.TestCase):
@@ -272,6 +273,19 @@ def test_restored_multi_output_type(self, out_type):
272273
out_val = model_restored(Input(shape=(3,), batch_size=2))
273274
self.assertIsInstance(out_val, out_type)
274275

276+
def test_restored_nested_input(self):
277+
input_a = Input(shape=(3,), batch_size=2, name="input_a")
278+
x = layers.Dense(5)(input_a)
279+
outputs = layers.Dense(4)(x)
280+
model = Functional([[input_a]], outputs)
281+
282+
# Serialize and deserialize the model
283+
json_config = model.to_json()
284+
restored_json_config = model_from_json(json_config).to_json()
285+
286+
# Check that the serialized model is the same as the original
287+
self.assertEqual(json_config, restored_json_config)
288+
275289
@pytest.mark.requires_trainable_backend
276290
def test_layer_getters(self):
277291
# Test mixing ops and layers

0 commit comments

Comments
 (0)