diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index eaff1a8376a2..4ef338b668a1 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -44,7 +44,6 @@
 from keras.src.metrics.metric import Metric
 from keras.src.ops.node import Node
 from keras.src.ops.operation import Operation
-from keras.src.saving.keras_saveable import KerasSaveable
 from keras.src.utils import python_utils
 from keras.src.utils import summary_utils
 from keras.src.utils import traceback_utils
@@ -67,7 +66,7 @@


 @keras_export(["keras.Layer", "keras.layers.Layer"])
-class Layer(BackendLayer, Operation, KerasSaveable):
+class Layer(BackendLayer, Operation):
     """This is the class from which all layers inherit.

     A layer is a callable object that takes as input one or more tensors and
diff --git a/keras/src/layers/preprocessing/feature_space.py b/keras/src/layers/preprocessing/feature_space.py
index 5fc5e34afafa..5f219dc1cf1c 100644
--- a/keras/src/layers/preprocessing/feature_space.py
+++ b/keras/src/layers/preprocessing/feature_space.py
@@ -6,12 +6,13 @@
 from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
 from keras.src.saving import saving_lib
 from keras.src.saving import serialization_lib
+from keras.src.saving.keras_saveable import KerasSaveable
 from keras.src.utils import backend_utils
 from keras.src.utils.module_utils import tensorflow as tf
 from keras.src.utils.naming import auto_name


-class Cross:
+class Cross(KerasSaveable):
     def __init__(self, feature_names, crossing_dim, output_mode="one_hot"):
         if output_mode not in {"int", "one_hot"}:
             raise ValueError(
@@ -23,6 +24,9 @@ def __init__(self, feature_names, crossing_dim, output_mode="one_hot"):
         self.crossing_dim = crossing_dim
         self.output_mode = output_mode

+    def _obj_type(self):
+        return "Cross"
+
     @property
     def name(self):
         return "_X_".join(self.feature_names)
@@ -39,7 +43,7 @@ def from_config(cls, config):
         return cls(**config)


-class Feature:
+class Feature(KerasSaveable):
     def __init__(self, dtype, preprocessor, output_mode):
         if output_mode not in {"int", "one_hot", "float"}:
             raise ValueError(
@@ -55,6 +59,9 @@ def __init__(self, dtype, preprocessor, output_mode):
         self.preprocessor = preprocessor
         self.output_mode = output_mode

+    def _obj_type(self):
+        return "Feature"
+
     def get_config(self):
         return {
             "dtype": self.dtype,
diff --git a/keras/src/legacy/saving/legacy_h5_format.py b/keras/src/legacy/saving/legacy_h5_format.py
index d7f3c3eb7ded..5b919f80e7c6 100644
--- a/keras/src/legacy/saving/legacy_h5_format.py
+++ b/keras/src/legacy/saving/legacy_h5_format.py
@@ -6,7 +6,6 @@
 from absl import logging

 from keras.src import backend
-from keras.src import optimizers
 from keras.src.backend.common import global_state
 from keras.src.legacy.saving import json_utils
 from keras.src.legacy.saving import saving_options
@@ -161,6 +160,8 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
         # Set optimizer weights.
         if "optimizer_weights" in f:
             try:
+                from keras.src import optimizers
+
                 if isinstance(model.optimizer, optimizers.Optimizer):
                     model.optimizer.build(model._trainable_variables)
                 else:
@@ -249,6 +250,8 @@ def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
         hdf5_group: HDF5 group.
         optimizer: optimizer instance.
     """
+    from keras.src import optimizers
+
     if isinstance(optimizer, optimizers.Optimizer):
         symbolic_weights = optimizer.variables
     else:
diff --git a/keras/src/legacy/saving/saving_utils.py b/keras/src/legacy/saving/saving_utils.py
index 5780ad701163..1373ba11e785 100644
--- a/keras/src/legacy/saving/saving_utils.py
+++ b/keras/src/legacy/saving/saving_utils.py
@@ -4,11 +4,8 @@
 from absl import logging

 from keras.src import backend
-from keras.src import layers
 from keras.src import losses
 from keras.src import metrics as metrics_module
-from keras.src import models
-from keras.src import optimizers
 from keras.src import tree
 from keras.src.legacy.saving import serialization
 from keras.src.saving import object_registration
@@ -49,6 +46,9 @@ def model_from_config(config, custom_objects=None):
     global MODULE_OBJECTS

     if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"):
+        from keras.src import layers
+        from keras.src import models
+
         MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__
         MODULE_OBJECTS.ALL_OBJECTS["InputLayer"] = layers.InputLayer
         MODULE_OBJECTS.ALL_OBJECTS["Functional"] = models.Functional
@@ -132,6 +132,8 @@ def compile_args_from_training_config(training_config, custom_objects=None):
         custom_objects = {}

     with object_registration.CustomObjectScope(custom_objects):
+        from keras.src import optimizers
+
         optimizer_config = training_config["optimizer_config"]
         optimizer = optimizers.deserialize(optimizer_config)
         # Ensure backwards compatibility for optimizers in legacy H5 files
diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py
index 9529a8e689f1..5813593340e3 100644
--- a/keras/src/ops/operation.py
+++ b/keras/src/ops/operation.py
@@ -7,13 +7,14 @@
 from keras.src.api_export import keras_export
 from keras.src.backend.common.keras_tensor import any_symbolic_tensors
 from keras.src.ops.node import Node
+from keras.src.saving.keras_saveable import KerasSaveable
 from keras.src.utils import python_utils
 from keras.src.utils import traceback_utils
 from keras.src.utils.naming import auto_name


 @keras_export("keras.Operation")
-class Operation:
+class Operation(KerasSaveable):
     def __init__(self, name=None):
         if name is None:
             name = auto_name(self.__class__.__name__)
@@ -311,6 +312,9 @@ def _get_node_attribute_at_index(self, node_index, attr, attr_name):
         else:
             return values

+    def _obj_type(self):
+        return "Operation"
+
     # Hooks for backend layer classes
     def _post_build(self):
         """Can be overridden for per backend post build actions."""
diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py
index 3d19e81ddec6..01d0b0bbb031 100644
--- a/keras/src/saving/saving_lib.py
+++ b/keras/src/saving/saving_lib.py
@@ -16,14 +16,9 @@

 from keras.src import backend
 from keras.src.backend.common import global_state
-from keras.src.layers.layer import Layer
-from keras.src.losses.loss import Loss
-from keras.src.metrics.metric import Metric
-from keras.src.optimizers.optimizer import Optimizer
 from keras.src.saving.serialization_lib import ObjectSharingScope
 from keras.src.saving.serialization_lib import deserialize_keras_object
 from keras.src.saving.serialization_lib import serialize_keras_object
-from keras.src.trainers.compile_utils import CompileMetrics
 from keras.src.utils import dtype_utils
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
@@ -1584,32 +1579,60 @@ def get_attr_skipset(obj_type):
             "_self_unconditional_dependency_names",
         ]
     )
+    if obj_type == "Operation":
+        from keras.src.ops.operation import Operation
+
+        ref_obj = Operation()
+        skipset.update(dir(ref_obj))
     if obj_type == "Layer":
+        from keras.src.layers.layer import Layer
+
         ref_obj = Layer()
         skipset.update(dir(ref_obj))
     elif obj_type == "Functional":
+        from keras.src.layers.layer import Layer
+
         ref_obj = Layer()
         skipset.update(dir(ref_obj) + ["operations", "_operations"])
     elif obj_type == "Sequential":
+        from keras.src.layers.layer import Layer
+
         ref_obj = Layer()
         skipset.update(dir(ref_obj) + ["_functional"])
     elif obj_type == "Metric":
+        from keras.src.metrics.metric import Metric
+        from keras.src.trainers.compile_utils import CompileMetrics
+
         ref_obj_a = Metric()
         ref_obj_b = CompileMetrics([], [])
         skipset.update(dir(ref_obj_a) + dir(ref_obj_b))
     elif obj_type == "Optimizer":
+        from keras.src.optimizers.optimizer import Optimizer
+
         ref_obj = Optimizer(1.0)
         skipset.update(dir(ref_obj))
         skipset.remove("variables")
     elif obj_type == "Loss":
+        from keras.src.losses.loss import Loss
+
         ref_obj = Loss()
         skipset.update(dir(ref_obj))
+    elif obj_type == "Cross":
+        from keras.src.layers.preprocessing.feature_space import Cross
+
+        ref_obj = Cross((), 1)
+        skipset.update(dir(ref_obj))
+    elif obj_type == "Feature":
+        from keras.src.layers.preprocessing.feature_space import Feature
+
+        ref_obj = Feature("int32", lambda x: x, "int")
+        skipset.update(dir(ref_obj))
     else:
         raise ValueError(
             f"get_attr_skipset got invalid {obj_type=}. "
             "Accepted values for `obj_type` are "
             "['Layer', 'Functional', 'Sequential', 'Metric', "
-            "'Optimizer', 'Loss']"
+            "'Optimizer', 'Loss', 'Cross', 'Feature']"
         )

     global_state.set_global_attribute(
diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py
index 53b88f389407..180176698d76 100644
--- a/keras/src/saving/serialization_lib.py
+++ b/keras/src/saving/serialization_lib.py
@@ -12,6 +12,7 @@
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 from keras.src.saving import object_registration
+from keras.src.saving.keras_saveable import KerasSaveable
 from keras.src.utils import python_utils
 from keras.src.utils.module_utils import tensorflow as tf

@@ -32,6 +33,7 @@

 LOADING_APIS = frozenset(
     {
+        "keras.config.enable_unsafe_deserialization",
         "keras.models.load_model",
         "keras.preprocessing.image.load_img",
         "keras.saving.load_model",
@@ -817,8 +819,13 @@ def _retrieve_class_or_fn(
     try:
         mod = importlib.import_module(module)
         obj = vars(mod).get(name, None)
-        if obj is not None:
+        if isinstance(obj, type) and issubclass(obj, KerasSaveable):
             return obj
+        else:
+            raise ValueError(
+                f"Could not deserialize '{module}.{name}' because "
+                "it is not a KerasSaveable subclass"
+            )
     except ModuleNotFoundError:
         raise TypeError(
             f"Could not deserialize {obj_type} '{name}' because "
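
Note (commentary, not part of the patch): the change hardens deserialization. `_retrieve_class_or_fn` previously returned *any* attribute found on an importable module, so a crafted saved config could resolve arbitrary callables; it now only returns classes that subclass `KerasSaveable`. To keep every legitimately saveable object loadable, `KerasSaveable` moves from `Layer` up to `Operation`, and the `Cross`/`Feature` helpers of `FeatureSpace` gain the base class plus the `_obj_type()` hook that `get_attr_skipset` keys on. Because `operation.py` now imports from `keras.src.saving`, the remaining hunks convert several top-level imports into function-local ones to break the resulting import cycles.

Below is a minimal, self-contained sketch of the gating idea; `SafeBase` and `retrieve_class` are illustrative stand-ins for `KerasSaveable` and `_retrieve_class_or_fn`, not Keras APIs:

```python
import importlib


class SafeBase:
    """Stand-in for KerasSaveable: only its subclasses may be resolved."""


def retrieve_class(module, name):
    # Resolve "module.name" the way a saved config does...
    mod = importlib.import_module(module)
    obj = vars(mod).get(name, None)
    # ...but only hand back allow-listed classes. The old check,
    # `if obj is not None`, would happily resolve e.g. os.system here.
    if isinstance(obj, type) and issubclass(obj, SafeBase):
        return obj
    raise ValueError(
        f"Could not deserialize '{module}.{name}' because "
        "it is not a SafeBase subclass"
    )


# retrieve_class("os", "system") now raises ValueError instead of
# returning a callable that the loader might later invoke.
```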