Skip to content

Only allow deserialization of KerasSaveables by module and name. #21429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into the base branch on Jun 29, 2025
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from keras.src.metrics.metric import Metric
from keras.src.ops.node import Node
from keras.src.ops.operation import Operation
from keras.src.saving.keras_saveable import KerasSaveable
from keras.src.utils import python_utils
from keras.src.utils import summary_utils
from keras.src.utils import traceback_utils
Expand All @@ -67,7 +66,7 @@


@keras_export(["keras.Layer", "keras.layers.Layer"])
class Layer(BackendLayer, Operation, KerasSaveable):
class Layer(BackendLayer, Operation):
"""This is the class from which all layers inherit.

A layer is a callable object that takes as input one or more tensors and
Expand Down
11 changes: 9 additions & 2 deletions keras/src/layers/preprocessing/feature_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
from keras.src.saving import saving_lib
from keras.src.saving import serialization_lib
from keras.src.saving.keras_saveable import KerasSaveable
from keras.src.utils import backend_utils
from keras.src.utils.module_utils import tensorflow as tf
from keras.src.utils.naming import auto_name


class Cross:
class Cross(KerasSaveable):
def __init__(self, feature_names, crossing_dim, output_mode="one_hot"):
if output_mode not in {"int", "one_hot"}:
raise ValueError(
Expand All @@ -23,6 +24,9 @@ def __init__(self, feature_names, crossing_dim, output_mode="one_hot"):
self.crossing_dim = crossing_dim
self.output_mode = output_mode

def _obj_type(self):
return "Cross"

@property
def name(self):
return "_X_".join(self.feature_names)
Expand All @@ -39,7 +43,7 @@ def from_config(cls, config):
return cls(**config)


class Feature:
class Feature(KerasSaveable):
def __init__(self, dtype, preprocessor, output_mode):
if output_mode not in {"int", "one_hot", "float"}:
raise ValueError(
Expand All @@ -55,6 +59,9 @@ def __init__(self, dtype, preprocessor, output_mode):
self.preprocessor = preprocessor
self.output_mode = output_mode

def _obj_type(self):
return "Feature"

def get_config(self):
return {
"dtype": self.dtype,
Expand Down
5 changes: 4 additions & 1 deletion keras/src/legacy/saving/legacy_h5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from absl import logging

from keras.src import backend
from keras.src import optimizers
from keras.src.backend.common import global_state
from keras.src.legacy.saving import json_utils
from keras.src.legacy.saving import saving_options
Expand Down Expand Up @@ -161,6 +160,8 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
# Set optimizer weights.
if "optimizer_weights" in f:
try:
from keras.src import optimizers

if isinstance(model.optimizer, optimizers.Optimizer):
model.optimizer.build(model._trainable_variables)
else:
Expand Down Expand Up @@ -249,6 +250,8 @@ def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
hdf5_group: HDF5 group.
optimizer: optimizer instance.
"""
from keras.src import optimizers

if isinstance(optimizer, optimizers.Optimizer):
symbolic_weights = optimizer.variables
else:
Expand Down
8 changes: 5 additions & 3 deletions keras/src/legacy/saving/saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
from absl import logging

from keras.src import backend
from keras.src import layers
from keras.src import losses
from keras.src import metrics as metrics_module
from keras.src import models
from keras.src import optimizers
from keras.src import tree
from keras.src.legacy.saving import serialization
from keras.src.saving import object_registration
Expand Down Expand Up @@ -49,6 +46,9 @@ def model_from_config(config, custom_objects=None):
global MODULE_OBJECTS

if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"):
from keras.src import layers
from keras.src import models

MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__
MODULE_OBJECTS.ALL_OBJECTS["InputLayer"] = layers.InputLayer
MODULE_OBJECTS.ALL_OBJECTS["Functional"] = models.Functional
Expand Down Expand Up @@ -132,6 +132,8 @@ def compile_args_from_training_config(training_config, custom_objects=None):
custom_objects = {}

with object_registration.CustomObjectScope(custom_objects):
from keras.src import optimizers

optimizer_config = training_config["optimizer_config"]
optimizer = optimizers.deserialize(optimizer_config)
# Ensure backwards compatibility for optimizers in legacy H5 files
Expand Down
6 changes: 5 additions & 1 deletion keras/src/ops/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from keras.src.api_export import keras_export
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
from keras.src.ops.node import Node
from keras.src.saving.keras_saveable import KerasSaveable
from keras.src.utils import python_utils
from keras.src.utils import traceback_utils
from keras.src.utils.naming import auto_name


@keras_export("keras.Operation")
class Operation:
class Operation(KerasSaveable):
def __init__(self, name=None):
if name is None:
name = auto_name(self.__class__.__name__)
Expand Down Expand Up @@ -311,6 +312,9 @@ def _get_node_attribute_at_index(self, node_index, attr, attr_name):
else:
return values

def _obj_type(self):
return "Operation"

# Hooks for backend layer classes
def _post_build(self):
"""Can be overridden for per backend post build actions."""
Expand Down
35 changes: 29 additions & 6 deletions keras/src/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@

from keras.src import backend
from keras.src.backend.common import global_state
from keras.src.layers.layer import Layer
from keras.src.losses.loss import Loss
from keras.src.metrics.metric import Metric
from keras.src.optimizers.optimizer import Optimizer
from keras.src.saving.serialization_lib import ObjectSharingScope
from keras.src.saving.serialization_lib import deserialize_keras_object
from keras.src.saving.serialization_lib import serialize_keras_object
from keras.src.trainers.compile_utils import CompileMetrics
from keras.src.utils import dtype_utils
from keras.src.utils import file_utils
from keras.src.utils import io_utils
Expand Down Expand Up @@ -1584,32 +1579,60 @@ def get_attr_skipset(obj_type):
"_self_unconditional_dependency_names",
]
)
if obj_type == "Operation":
from keras.src.ops.operation import Operation

ref_obj = Operation()
skipset.update(dir(ref_obj))
if obj_type == "Layer":
from keras.src.layers.layer import Layer

ref_obj = Layer()
skipset.update(dir(ref_obj))
elif obj_type == "Functional":
from keras.src.layers.layer import Layer

ref_obj = Layer()
skipset.update(dir(ref_obj) + ["operations", "_operations"])
elif obj_type == "Sequential":
from keras.src.layers.layer import Layer

ref_obj = Layer()
skipset.update(dir(ref_obj) + ["_functional"])
elif obj_type == "Metric":
from keras.src.metrics.metric import Metric
from keras.src.trainers.compile_utils import CompileMetrics

ref_obj_a = Metric()
ref_obj_b = CompileMetrics([], [])
skipset.update(dir(ref_obj_a) + dir(ref_obj_b))
elif obj_type == "Optimizer":
from keras.src.optimizers.optimizer import Optimizer

ref_obj = Optimizer(1.0)
skipset.update(dir(ref_obj))
skipset.remove("variables")
elif obj_type == "Loss":
from keras.src.losses.loss import Loss

ref_obj = Loss()
skipset.update(dir(ref_obj))
elif obj_type == "Cross":
from keras.src.layers.preprocessing.feature_space import Cross

ref_obj = Cross((), 1)
skipset.update(dir(ref_obj))
elif obj_type == "Feature":
from keras.src.layers.preprocessing.feature_space import Feature

ref_obj = Feature("int32", lambda x: x, "int")
skipset.update(dir(ref_obj))
else:
raise ValueError(
f"get_attr_skipset got invalid {obj_type=}. "
"Accepted values for `obj_type` are "
"['Layer', 'Functional', 'Sequential', 'Metric', "
"'Optimizer', 'Loss']"
"'Optimizer', 'Loss', 'Cross', 'Feature']"
)

global_state.set_global_attribute(
Expand Down
9 changes: 8 additions & 1 deletion keras/src/saving/serialization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
from keras.src.saving import object_registration
from keras.src.saving.keras_saveable import KerasSaveable
from keras.src.utils import python_utils
from keras.src.utils.module_utils import tensorflow as tf

Expand All @@ -32,6 +33,7 @@

LOADING_APIS = frozenset(
{
"keras.config.enable_unsafe_deserialization",
"keras.models.load_model",
"keras.preprocessing.image.load_img",
"keras.saving.load_model",
Expand Down Expand Up @@ -817,8 +819,13 @@ def _retrieve_class_or_fn(
try:
mod = importlib.import_module(module)
obj = vars(mod).get(name, None)
if obj is not None:
if isinstance(obj, type) and issubclass(obj, KerasSaveable):
return obj
else:
raise ValueError(
f"Could not deserialize '{module}.{name}' because "
"it is not a KerasSaveable subclass"
)
except ModuleNotFoundError:
raise TypeError(
f"Could not deserialize {obj_type} '{name}' because "
Expand Down