Skip to content

Commit 199c5c7

Browse files
authored
Do no always make batch size dynamic during export. (#21944)
This is a follow-up of #21674 This PR changed the signature of `make_tf_tensor_spec` from `(x)` to `(x, dynamic_batch=True)`, thereby adding the ability to make the batch size dynamic. This PR also adds `_get_save_spec(self, dynamic_batch=True)` which uses `make_tf_tensor_spec` and forwards the `dynamic_batch` argument. However, the default before this change for other export (SavedModel, ONNX) was to keep the batch size untouched. In particular, when a user manually provides an `input_signature` to [`ExportArchive.add_endpoint`](https://github.com/keras-team/keras/blob/master/keras/src/export/saved_model.py#L362), we should honor. The user controls whether the batch size is dynamic or not in the `input_signature`. This PR changes the default of `make_tf_tensor_spec` back to `dynamic_batch=False` to revert SavedModel and ONNX exports to the previous behavior. Also removed call to `return super()._get_save_spec(dynamic_batch)` which can never succeed as `TFLayer` is a top level class (ignoring the auto-tracking stuff).
1 parent 203b72c commit 199c5c7

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

keras/src/backend/tensorflow/layer.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,18 @@ def _get_save_spec(self, dynamic_batch=True):
9494
A TensorSpec, list or dict mirroring the model inputs, or
9595
`None` when specs cannot be inferred.
9696
"""
97-
# Prefer the base implementation if available
98-
try:
99-
return super()._get_save_spec(dynamic_batch)
100-
except AttributeError:
101-
# Lazy import to avoid circular dependency
102-
from keras.src.export.export_utils import make_tf_tensor_spec
103-
104-
# Fall back to building specs from `self.inputs`
105-
inputs = getattr(self, "inputs", None)
106-
if inputs is None:
107-
return None
108-
109-
return tree.map_structure(
110-
lambda x: make_tf_tensor_spec(x, dynamic_batch=dynamic_batch),
111-
inputs,
112-
)
97+
# Lazy import to avoid circular dependency
98+
from keras.src.export.export_utils import make_tf_tensor_spec
99+
100+
# Fall back to building specs from `self.inputs`
101+
inputs = getattr(self, "inputs", None)
102+
if inputs is None:
103+
return None
104+
105+
return tree.map_structure(
106+
lambda x: make_tf_tensor_spec(x, dynamic_batch=dynamic_batch),
107+
inputs,
108+
)
113109

114110
@property
115111
def _default_save_signature(self):

keras/src/export/export_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def make_input_spec(x):
102102
return input_spec
103103

104104

105-
def make_tf_tensor_spec(x, dynamic_batch=True):
105+
def make_tf_tensor_spec(x, dynamic_batch=False):
106106
"""Create a TensorSpec from various input types.
107107
108108
Args:

0 commit comments

Comments
 (0)