"
+
+ def __iter__(self):
+ keys = sorted(self._config.keys())
+ for k in keys:
+ yield k
+
+ def __len__(self):
+ return len(self._config)
+
+ def __delitem__(self, key):
+ self._raise_if_frozen()
+ del self._config[key]
+
+ def __contains__(self, item):
+ return item in self._config
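The hunk above completes the dict protocol on the config object: sorted key iteration, length, deletion guarded by the freeze check, and membership. A minimal standalone sketch of the same protocol, using a hypothetical `SimpleConfig` stand-in rather than the actual Keras class:

```python
# Minimal sketch of the mapping protocol added above. `SimpleConfig` and its
# freeze flag are illustrative stand-ins, not the Keras implementation.
class SimpleConfig:
    def __init__(self, **entries):
        self._config = dict(entries)
        self._frozen = False

    def _raise_if_frozen(self):
        if self._frozen:
            raise ValueError("Cannot mutate a frozen config.")

    def __iter__(self):
        # Iterate keys in sorted order, as in the diff above.
        return iter(sorted(self._config.keys()))

    def __len__(self):
        return len(self._config)

    def __delitem__(self, key):
        self._raise_if_frozen()
        del self._config[key]

    def __contains__(self, item):
        return item in self._config


config = SimpleConfig(b=2, a=1)
assert list(config) == ["a", "b"] and "a" in config and len(config) == 2
```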
diff --git a/keras/src/utils/dataset_utils.py b/keras/src/utils/dataset_utils.py
index ff27552278f4..f1b6b7e02f10 100644
--- a/keras/src/utils/dataset_utils.py
+++ b/keras/src/utils/dataset_utils.py
@@ -6,6 +6,7 @@
import numpy as np
+from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.utils import io_utils
from keras.src.utils.module_utils import tensorflow as tf
@@ -137,16 +138,7 @@ def _convert_dataset_to_list(
data_size_warning_flag,
start_time,
):
- if dataset_type_spec in [tuple, list]:
- # The try-except here is for NumPy 1.24 compatibility, see:
- # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
- try:
- arr = np.array(sample)
- except ValueError:
- arr = np.array(sample, dtype=object)
- dataset_as_list.append(arr)
- else:
- dataset_as_list.append(sample)
+ dataset_as_list.append(sample)
return dataset_as_list
@@ -162,66 +154,66 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
Returns:
iterator: An `iterator` object.
"""
- if dataset_type_spec == list:
+ if dataset_type_spec is list:
if len(dataset) == 0:
raise ValueError(
"Received an empty list dataset. "
"Please provide a non-empty list of arrays."
)
- if _get_type_spec(dataset[0]) is np.ndarray:
- expected_shape = dataset[0].shape
- for i, element in enumerate(dataset):
- if np.array(element).shape[0] != expected_shape[0]:
- raise ValueError(
- "Received a list of NumPy arrays with different "
- f"lengths. Mismatch found at index {i}, "
- f"Expected shape={expected_shape} "
- f"Received shape={np.array(element).shape}."
- "Please provide a list of NumPy arrays with "
- "the same length."
- )
- else:
- raise ValueError(
- "Expected a list of `numpy.ndarray` objects,"
- f"Received: {type(dataset[0])}"
- )
+ expected_shape = None
+ for i, element in enumerate(dataset):
+ if not isinstance(element, np.ndarray):
+ raise ValueError(
+ "Expected a list of `numpy.ndarray` objects,"
+ f"Received: {type(element)} at index {i}."
+ )
+ if expected_shape is None:
+ expected_shape = element.shape
+ elif element.shape[0] != expected_shape[0]:
+ raise ValueError(
+ "Received a list of NumPy arrays with different lengths."
+ f"Mismatch found at index {i}, "
+ f"Expected shape={expected_shape} "
+ f"Received shape={np.array(element).shape}."
+ "Please provide a list of NumPy arrays of the same length."
+ )
return iter(zip(*dataset))
- elif dataset_type_spec == tuple:
+ elif dataset_type_spec is tuple:
if len(dataset) == 0:
raise ValueError(
"Received an empty list dataset."
"Please provide a non-empty tuple of arrays."
)
- if _get_type_spec(dataset[0]) is np.ndarray:
- expected_shape = dataset[0].shape
- for i, element in enumerate(dataset):
- if np.array(element).shape[0] != expected_shape[0]:
- raise ValueError(
- "Received a tuple of NumPy arrays with different "
- f"lengths. Mismatch found at index {i}, "
- f"Expected shape={expected_shape} "
- f"Received shape={np.array(element).shape}."
- "Please provide a tuple of NumPy arrays with "
- "the same length."
- )
- else:
- raise ValueError(
- "Expected a tuple of `numpy.ndarray` objects, "
- f"Received: {type(dataset[0])}"
- )
+ expected_shape = None
+ for i, element in enumerate(dataset):
+ if not isinstance(element, np.ndarray):
+ raise ValueError(
+ "Expected a tuple of `numpy.ndarray` objects,"
+ f"Received: {type(element)} at index {i}."
+ )
+ if expected_shape is None:
+ expected_shape = element.shape
+ elif element.shape[0] != expected_shape[0]:
+ raise ValueError(
+ "Received a tuple of NumPy arrays with different lengths."
+ f"Mismatch found at index {i}, "
+ f"Expected shape={expected_shape} "
+ f"Received shape={np.array(element).shape}."
+ "Please provide a tuple of NumPy arrays of the same length."
+ )
return iter(zip(*dataset))
- elif dataset_type_spec == tf.data.Dataset:
+ elif dataset_type_spec is tf.data.Dataset:
if is_batched(dataset):
dataset = dataset.unbatch()
return iter(dataset)
elif is_torch_dataset(dataset):
return iter(dataset)
- elif dataset_type_spec == np.ndarray:
+ elif dataset_type_spec is np.ndarray:
return iter(dataset)
raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}")
@@ -358,9 +350,9 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
# check left_size is non-negative and less than 1 and less than total_length
if (
- left_size_type == int
+ left_size_type is int
and (left_size <= 0 or left_size >= total_length)
- or left_size_type == float
+ or left_size_type is float
and (left_size <= 0 or left_size >= 1)
):
raise ValueError(
@@ -373,9 +365,9 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
# check right_size is non-negative and less than 1 and less than
# total_length
if (
- right_size_type == int
+ right_size_type is int
and (right_size <= 0 or right_size >= total_length)
- or right_size_type == float
+ or right_size_type is float
and (right_size <= 0 or right_size >= 1)
):
raise ValueError(
@@ -388,7 +380,7 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
# check sum of left_size and right_size is less than or equal to
# total_length
if (
- right_size_type == left_size_type == float
+ right_size_type is left_size_type is float
and right_size + left_size > 1
):
raise ValueError(
@@ -396,14 +388,14 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
"than 1. It must be less than or equal to 1."
)
- if left_size_type == float:
+ if left_size_type is float:
left_size = round(left_size * total_length)
- elif left_size_type == int:
+ elif left_size_type is int:
left_size = float(left_size)
- if right_size_type == float:
+ if right_size_type is float:
right_size = round(right_size * total_length)
- elif right_size_type == int:
+ elif right_size_type is int:
right_size = float(right_size)
if left_size is None:
@@ -415,7 +407,7 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
raise ValueError(
"The sum of `left_size` and `right_size` should "
"be smaller than the {total_length}. "
- f"Received: left_size + right_size = {left_size+right_size}"
+ f"Received: left_size + right_size = {left_size + right_size}"
f"and total_length = {total_length}"
)
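A worked example of the float path in `_rescale_dataset_split_sizes`: fractional sizes are scaled by the total length and rounded, so the defaults used throughout the tests below split 100 samples into 20 and 80.

```python
# Illustrative arithmetic only; mirrors the float branch above.
total_length = 100
left_size, right_size = 0.2, 0.8                 # both floats, sum <= 1
left_size = round(left_size * total_length)      # -> 20
right_size = round(right_size * total_length)    # -> 80
assert (left_size, right_size) == (20, 80)
```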
@@ -436,23 +428,24 @@ def _restore_dataset_from_list(
dataset_as_list, dataset_type_spec, original_dataset
):
"""Restore the dataset from the list of arrays."""
- if dataset_type_spec in [tuple, list]:
- return tuple(np.array(sample) for sample in zip(*dataset_as_list))
- elif dataset_type_spec == tf.data.Dataset:
- if isinstance(original_dataset.element_spec, dict):
- restored_dataset = {}
- for d in dataset_as_list:
- for k, v in d.items():
- if k not in restored_dataset:
- restored_dataset[k] = [v]
- else:
- restored_dataset[k].append(v)
- return restored_dataset
- else:
- return tuple(np.array(sample) for sample in zip(*dataset_as_list))
+ if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset(
+ original_dataset
+ ):
+ # Save structure by taking the first element.
+ element_spec = dataset_as_list[0]
+ # Flatten each element.
+ dataset_as_list = [tree.flatten(sample) for sample in dataset_as_list]
+ # Combine respective elements at all indices.
+ dataset_as_list = [np.array(sample) for sample in zip(*dataset_as_list)]
+ # Recreate the original structure of elements.
+ dataset_as_list = tree.pack_sequence_as(element_spec, dataset_as_list)
+ # Turn lists to tuples as tf.data will fail on lists.
+ return tree.traverse(
+ lambda x: tuple(x) if isinstance(x, list) else x,
+ dataset_as_list,
+ top_down=False,
+ )
- elif is_torch_dataset(original_dataset):
- return tuple(np.array(sample) for sample in zip(*dataset_as_list))
return dataset_as_list
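The rewritten `_restore_dataset_from_list` leans on `tree.flatten` / `tree.pack_sequence_as` so that nested structures (tuples, dicts, etc.) are handled uniformly instead of only flat tuples. A dependency-free sketch of the same idea for a dict-structured dataset, assuming sorted-key flattening order as in `tf.nest`:

```python
import numpy as np

# Each entry of dataset_as_list is one sample with the same nested structure.
dataset_as_list = [{"x": np.array([i, i]), "y": np.array([i])} for i in range(4)]

keys = sorted(dataset_as_list[0])                                 # structure template
flat = [[sample[k] for k in keys] for sample in dataset_as_list]  # flatten each
stacked = [np.array(column) for column in zip(*flat)]             # combine across samples
restored = dict(zip(keys, stacked))                               # pack back into structure

assert restored["x"].shape == (4, 2) and restored["y"].shape == (4, 1)
```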
@@ -477,14 +470,12 @@ def _get_type_spec(dataset):
return list
elif isinstance(dataset, np.ndarray):
return np.ndarray
- elif isinstance(dataset, dict):
- return dict
elif isinstance(dataset, tf.data.Dataset):
return tf.data.Dataset
elif is_torch_dataset(dataset):
- from torch.utils.data import Dataset as torchDataset
+ from torch.utils.data import Dataset as TorchDataset
- return torchDataset
+ return TorchDataset
else:
return None
@@ -672,7 +663,7 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
def get_training_or_validation_split(samples, labels, validation_split, subset):
- """Potentially restict samples & labels to a training or validation split.
+ """Potentially restrict samples & labels to a training or validation split.
Args:
samples: List of elements.
@@ -691,7 +682,7 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
num_val_samples = int(validation_split * len(samples))
if subset == "training":
io_utils.print_msg(
- f"Using {len(samples) - num_val_samples} " f"files for training."
+ f"Using {len(samples) - num_val_samples} files for training."
)
samples = samples[:-num_val_samples]
if labels is not None:
diff --git a/keras/src/utils/dataset_utils_test.py b/keras/src/utils/dataset_utils_test.py
index 11cb275ff815..c907736cc0ba 100644
--- a/keras/src/utils/dataset_utils_test.py
+++ b/keras/src/utils/dataset_utils_test.py
@@ -1,127 +1,49 @@
+import collections
+import itertools
+
import numpy as np
+from absl.testing import parameterized
+from torch.utils.data import Dataset as TorchDataset
from keras.src.testing import test_case
+from keras.src.testing.test_utils import named_product
from keras.src.utils.dataset_utils import split_dataset
from keras.src.utils.module_utils import tensorflow as tf
-class DatasetUtilsTest(test_case.TestCase):
- def test_split_dataset_list(self):
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = [
- np.random.sample((n_sample, n_cols)),
- np.random.sample((n_sample, n_pred)),
- ]
- dataset_left, dataset_right = split_dataset(
- dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (n_cols)
- )
+class MyTorchDataset(TorchDataset):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = [
- np.random.sample((n_sample, 100, n_cols)),
- np.random.sample((n_sample, n_pred)),
- ]
- dataset_left, dataset_right = split_dataset(
- dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (100, n_cols)
- )
+ def __len__(self):
+ return len(self.x)
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = [
- np.random.sample((n_sample, 10, 10, n_cols)),
- np.random.sample((n_sample, n_pred)),
- ]
- dataset_left, dataset_right = split_dataset(
- dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (10, 10, n_cols)
- )
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = [
- np.random.sample((n_sample, 100, 10, 30, n_cols)),
- np.random.sample((n_sample, n_pred)),
- ]
- dataset_left, dataset_right = split_dataset(
- dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape,
- (100, 10, 30, n_cols),
- )
+ def __getitem__(self, index):
+ return self.x[index], self.y[index]
- def test_split_dataset_tuple(self):
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = (
- np.random.sample((n_sample, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- dataset_left, dataset_right = split_dataset(
- dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (n_cols)
- )
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = (
- np.random.sample((n_sample, 100, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- dataset_left, dataset_right = split_dataset(
- dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (100, n_cols)
- )
+class DatasetUtilsTest(test_case.TestCase):
+ @parameterized.named_parameters(
+ named_product(
+ dataset_type=["list", "tuple", "tensorflow", "torch"],
+ features_shape=[(2,), (100, 2), (10, 10, 2)],
+ )
+ )
+ def test_split_dataset(self, dataset_type, features_shape):
+ n_sample, left_size, right_size = 100, 0.2, 0.8
+ features = np.random.sample((n_sample,) + features_shape)
+ labels = np.random.sample((n_sample, 1))
+
+ if dataset_type == "list":
+ dataset = [features, labels]
+ elif dataset_type == "tuple":
+ dataset = (features, labels)
+ elif dataset_type == "tensorflow":
+ dataset = tf.data.Dataset.from_tensor_slices((features, labels))
+ elif dataset_type == "torch":
+ dataset = MyTorchDataset(features, labels)
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = (
- np.random.sample((n_sample, 10, 10, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
dataset_left, dataset_right = split_dataset(
dataset, left_size=left_size, right_size=right_size
)
@@ -131,15 +53,34 @@ def test_split_dataset_tuple(self):
self.assertEqual(
int(dataset_right.cardinality()), int(n_sample * right_size)
)
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (10, 10, n_cols)
- )
+ for sample in itertools.chain(dataset_left, dataset_right):
+ self.assertEqual(sample[0].shape, features_shape)
+ self.assertEqual(sample[1].shape, (1,))
+
+ @parameterized.named_parameters(
+ named_product(structure_type=["tuple", "dict", "OrderedDict"])
+ )
+ def test_split_dataset_nested_structures(self, structure_type):
+ n_sample, left_size, right_size = 100, 0.2, 0.8
+ features1 = np.random.sample((n_sample, 2))
+ features2 = np.random.sample((n_sample, 10, 2))
+ labels = np.random.sample((n_sample, 1))
+
+ if structure_type == "tuple":
+ dataset = tf.data.Dataset.from_tensor_slices(
+ ((features1, features2), labels)
+ )
+ if structure_type == "dict":
+ dataset = tf.data.Dataset.from_tensor_slices(
+ {"y": features2, "x": features1, "labels": labels}
+ )
+ if structure_type == "OrderedDict":
+ dataset = tf.data.Dataset.from_tensor_slices(
+ collections.OrderedDict(
+ [("y", features2), ("x", features1), ("labels", labels)]
+ )
+ )
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- dataset = (
- np.random.sample((n_sample, 100, 10, 30, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
dataset_left, dataset_right = split_dataset(
dataset, left_size=left_size, right_size=right_size
)
@@ -149,186 +90,11 @@ def test_split_dataset_tuple(self):
self.assertEqual(
int(dataset_right.cardinality()), int(n_sample * right_size)
)
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape,
- (100, 10, 30, n_cols),
- )
-
- def test_split_dataset_tensorflow(self):
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
- dataset_left, dataset_right = split_dataset(
- tf_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (n_cols)
- )
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, 100, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
- dataset_left, dataset_right = split_dataset(
- tf_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (100, n_cols)
- )
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, 10, 10, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
- dataset_left, dataset_right = split_dataset(
- tf_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (10, 10, n_cols)
- )
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, 100, 10, 30, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
- dataset_left, dataset_right = split_dataset(
- tf_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- int(dataset_left.cardinality()), int(n_sample * left_size)
- )
- self.assertEqual(
- int(dataset_right.cardinality()), int(n_sample * right_size)
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape,
- (100, 10, 30, n_cols),
- )
-
- def test_split_dataset_torch(self):
- # sample torch dataset class
- from torch.utils.data import Dataset as torchDataset
-
- class Dataset(torchDataset):
- "Characterizes a dataset for PyTorch"
-
- def __init__(self, x, y):
- "Initialization"
- self.x = x
- self.y = y
-
- def __len__(self):
- "Denotes the total number of samples"
- return len(self.x)
-
- def __getitem__(self, index):
- "Generates one sample of data"
- return self.x[index], self.y[index]
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- torch_dataset = Dataset(features, labels)
- dataset_left, dataset_right = split_dataset(
- torch_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- len([sample for sample in dataset_left]), int(n_sample * left_size)
- )
- self.assertEqual(
- len([sample for sample in dataset_right]),
- int(n_sample * right_size),
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (n_cols,)
- )
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, 100, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- torch_dataset = Dataset(features, labels)
- dataset_left, dataset_right = split_dataset(
- torch_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- len([sample for sample in dataset_left]), int(n_sample * left_size)
- )
- self.assertEqual(
- len([sample for sample in dataset_right]),
- int(n_sample * right_size),
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (100, n_cols)
- )
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, 10, 10, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- torch_dataset = Dataset(features, labels)
- dataset_left, dataset_right = split_dataset(
- torch_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- len([sample for sample in dataset_left]), int(n_sample * left_size)
- )
- self.assertEqual(
- len([sample for sample in dataset_right]),
- int(n_sample * right_size),
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape, (10, 10, n_cols)
- )
-
- n_sample, n_cols, n_pred, left_size, right_size = 100, 2, 1, 0.2, 0.8
- features, labels = (
- np.random.sample((n_sample, 100, 10, 30, n_cols)),
- np.random.sample((n_sample, n_pred)),
- )
- torch_dataset = Dataset(features, labels)
- dataset_left, dataset_right = split_dataset(
- torch_dataset, left_size=left_size, right_size=right_size
- )
- self.assertEqual(
- len([sample for sample in dataset_left]), int(n_sample * left_size)
- )
- self.assertEqual(
- len([sample for sample in dataset_right]),
- int(n_sample * right_size),
- )
- self.assertEqual(
- [sample for sample in dataset_right][0][0].shape,
- (100, 10, 30, n_cols),
- )
+ for sample in itertools.chain(dataset_left, dataset_right):
+ if structure_type in ("dict", "OrderedDict"):
+ x, y, labels = sample["x"], sample["y"], sample["labels"]
+ elif structure_type == "tuple":
+ (x, y), labels = sample
+ self.assertEqual(x.shape, (2,))
+ self.assertEqual(y.shape, (10, 2))
+ self.assertEqual(labels.shape, (1,))
diff --git a/keras/src/utils/file_utils.py b/keras/src/utils/file_utils.py
index e625a9f131ce..7725c19c7cca 100644
--- a/keras/src/utils/file_utils.py
+++ b/keras/src/utils/file_utils.py
@@ -1,6 +1,5 @@
import hashlib
import os
-import pathlib
import re
import shutil
import tarfile
@@ -101,9 +100,11 @@ def extract_archive(file_path, path=".", archive_format="auto"):
if archive_type == "tar":
open_fn = tarfile.open
is_match_fn = tarfile.is_tarfile
- if archive_type == "zip":
+ elif archive_type == "zip":
open_fn = zipfile.ZipFile
is_match_fn = zipfile.is_zipfile
+ else:
+ raise NotImplementedError(archive_type)
if is_match_fn(file_path):
with open_fn(file_path) as archive:
@@ -163,14 +164,18 @@ def get_file(
```
Args:
- fname: Name of the file. If an absolute path, e.g. `"/path/to/file.txt"`
- is specified, the file will be saved at that location.
+ fname: If the target is a single file, this is your desired
+ local name for the file.
If `None`, the name of the file at `origin` will be used.
+ If downloading and extracting a directory archive,
+ the provided `fname` will be used as the extraction
+ directory name (only if it has no extension).
origin: Original URL of the file.
untar: Deprecated in favor of `extract` argument.
- boolean, whether the file should be decompressed
+ Boolean, whether the file is a tar archive that should
+ be extracted.
md5_hash: Deprecated in favor of `file_hash` argument.
- md5 hash of the file for verification
+ md5 hash of the file for file integrity verification.
file_hash: The expected hash string of the file after download.
The sha256 and md5 hash algorithms are both supported.
cache_subdir: Subdirectory under the Keras cache dir where the file is
@@ -179,7 +184,8 @@ def get_file(
hash_algorithm: Select the hash algorithm to verify the file.
options are `"md5'`, `"sha256'`, and `"auto'`.
The default 'auto' detects the hash algorithm in use.
- extract: True tries extracting the file as an Archive, like tar or zip.
+ extract: If `True`, extracts the archive. Only applicable to compressed
+ archive files like tar or zip.
archive_format: Archive format to try for extracting the file.
Options are `"auto'`, `"tar'`, `"zip'`, and `None`.
`"tar"` includes tar, tar.gz, and tar.bz files.
@@ -219,36 +225,50 @@ def get_file(
datadir = os.path.join(datadir_base, cache_subdir)
os.makedirs(datadir, exist_ok=True)
+ provided_fname = fname
fname = path_to_string(fname)
+
if not fname:
fname = os.path.basename(urllib.parse.urlsplit(origin).path)
if not fname:
raise ValueError(
"Can't parse the file name from the origin provided: "
f"'{origin}'."
- "Please specify the `fname` as the input param."
+ "Please specify the `fname` argument."
+ )
+ else:
+ if os.sep in fname:
+ raise ValueError(
+ "Paths are no longer accepted as the `fname` argument. "
+ "To specify the file's parent directory, use "
+ f"the `cache_dir` argument. Received: fname={fname}"
)
- if untar:
- if fname.endswith(".tar.gz"):
- fname = pathlib.Path(fname)
- # The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
- # considers it as 2 suffixes.
- fname = fname.with_suffix("").with_suffix("")
- fname = str(fname)
- untar_fpath = os.path.join(datadir, fname)
- fpath = untar_fpath + ".tar.gz"
+ if extract or untar:
+ if provided_fname:
+ if "." in fname:
+ download_target = os.path.join(datadir, fname)
+ fname = fname[: fname.find(".")]
+ extraction_dir = os.path.join(datadir, fname + "_extracted")
+ else:
+ extraction_dir = os.path.join(datadir, fname)
+ download_target = os.path.join(datadir, fname + "_archive")
+ else:
+ extraction_dir = os.path.join(datadir, fname)
+ download_target = os.path.join(datadir, fname + "_archive")
else:
- fpath = os.path.join(datadir, fname)
+ download_target = os.path.join(datadir, fname)
if force_download:
download = True
- elif os.path.exists(fpath):
+ elif os.path.exists(download_target):
# File found in cache.
download = False
# Verify integrity if a hash was provided.
if file_hash is not None:
- if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
+ if not validate_file(
+ download_target, file_hash, algorithm=hash_algorithm
+ ):
io_utils.print_msg(
"A local file was found, but it seems to be "
f"incomplete or outdated because the {hash_algorithm} "
@@ -270,9 +290,9 @@ def __init__(self):
self.finished = False
def __call__(self, block_num, block_size, total_size):
+ if total_size == -1:
+ total_size = None
if not self.progbar:
- if total_size == -1:
- total_size = None
self.progbar = Progbar(total_size)
current = block_num * block_size
@@ -288,21 +308,23 @@ def __call__(self, block_num, block_size, total_size):
error_msg = "URL fetch failure on {}: {} -- {}"
try:
try:
- urlretrieve(origin, fpath, DLProgbar())
+ urlretrieve(origin, download_target, DLProgbar())
except urllib.error.HTTPError as e:
raise Exception(error_msg.format(origin, e.code, e.msg))
except urllib.error.URLError as e:
raise Exception(error_msg.format(origin, e.errno, e.reason))
except (Exception, KeyboardInterrupt):
- if os.path.exists(fpath):
- os.remove(fpath)
+ if os.path.exists(download_target):
+ os.remove(download_target)
raise
# Validate download if succeeded and user provided an expected hash
# Security conscious users would get the hash of the file from a
# separate channel and pass it to this API to prevent MITM / corruption:
- if os.path.exists(fpath) and file_hash is not None:
- if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
+ if os.path.exists(download_target) and file_hash is not None:
+ if not validate_file(
+ download_target, file_hash, algorithm=hash_algorithm
+ ):
raise ValueError(
"Incomplete or corrupted file detected. "
f"The {hash_algorithm} "
@@ -310,21 +332,18 @@ def __call__(self, block_num, block_size, total_size):
f"of {file_hash}."
)
- if untar:
- if not os.path.exists(untar_fpath):
- status = extract_archive(fpath, datadir, archive_format="tar")
- if not status:
- warnings.warn("Could not extract archive.", stacklevel=2)
- return untar_fpath
+ if extract or untar:
+ if untar:
+ archive_format = "tar"
- if extract:
- status = extract_archive(fpath, datadir, archive_format)
+ status = extract_archive(
+ download_target, extraction_dir, archive_format
+ )
if not status:
warnings.warn("Could not extract archive.", stacklevel=2)
+ return extraction_dir
- # TODO: return extracted fpath if we extracted an archive,
- # rather than the archive path.
- return fpath
+ return download_target
def resolve_hasher(algorithm, file_hash=None):
@@ -403,7 +422,7 @@ def is_remote_path(filepath):
Returns:
bool: True if the filepath is a recognized remote path, otherwise False
"""
- if re.match(r"^(/cns|/cfs|/gcs|/hdfs|.*://).*$", str(filepath)):
+ if re.match(r"^(/cns|/cfs|/gcs|/hdfs|/readahead|.*://).*$", str(filepath)):
return True
return False
@@ -454,6 +473,15 @@ def isdir(path):
return os.path.isdir(path)
+def remove(path):
+ if is_remote_path(path):
+ if gfile.available:
+ return gfile.remove(path)
+ else:
+ _raise_if_no_gfile(path)
+ return os.remove(path)
+
+
def rmtree(path):
if is_remote_path(path):
if gfile.available:
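The net effect of the `get_file` changes: the archive always downloads to a stable `download_target`, extraction happens into a sibling directory, and that directory (not the archive path) is returned. A standalone sketch of the naming logic when a `fname` is provided:

```python
import os

# Illustrative paths only; the real code also handles fname=None and untar.
datadir = os.path.join("~", ".keras", "datasets")
fname = "my_dataset.tar.gz"

if "." in fname:
    download_target = os.path.join(datadir, fname)
    stem = fname[: fname.find(".")]
    extraction_dir = os.path.join(datadir, stem + "_extracted")
else:
    extraction_dir = os.path.join(datadir, fname)
    download_target = os.path.join(datadir, fname + "_archive")

assert os.path.basename(download_target) == "my_dataset.tar.gz"
assert os.path.basename(extraction_dir) == "my_dataset_extracted"
```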
diff --git a/keras/src/utils/file_utils_test.py b/keras/src/utils/file_utils_test.py
index c09f47acd1aa..428370a67041 100644
--- a/keras/src/utils/file_utils_test.py
+++ b/keras/src/utils/file_utils_test.py
@@ -58,7 +58,7 @@ def test_is_path_in_dir_with_absolute_paths(self):
class IsLinkInDirTest(test_case.TestCase):
def setUp(self):
- self._cleanup("test_path/to/base_dir")
+ self._cleanup(os.path.join("test_path", "to", "base_dir"))
self._cleanup("./base_dir")
def _cleanup(self, base_dir):
@@ -66,7 +66,7 @@ def _cleanup(self, base_dir):
shutil.rmtree(base_dir)
def test_is_link_in_dir_with_absolute_paths(self):
- base_dir = "test_path/to/base_dir"
+ base_dir = os.path.join("test_path", "to", "base_dir")
link_path = os.path.join(base_dir, "symlink")
target_path = os.path.join(base_dir, "file.txt")
@@ -120,7 +120,7 @@ def test_is_link_in_dir_with_relative_paths(self):
self.assertTrue(file_utils.is_link_in_dir(info, base_dir))
def tearDown(self):
- self._cleanup("test_path/to/base_dir")
+ self._cleanup(os.path.join("test_path", "to", "base_dir"))
self._cleanup("./base_dir")
@@ -319,7 +319,7 @@ def test_valid_tar_extraction(self):
"""Test valid tar.gz extraction and hash validation."""
dest_dir = self.get_temp_dir()
orig_dir = self.get_temp_dir()
- text_file_path, tar_file_path = self._create_tar_file(orig_dir)
+ _, tar_file_path = self._create_tar_file(orig_dir)
self._test_file_extraction_and_validation(
dest_dir, tar_file_path, "tar.gz"
)
@@ -328,7 +328,7 @@ def test_valid_zip_extraction(self):
"""Test valid zip extraction and hash validation."""
dest_dir = self.get_temp_dir()
orig_dir = self.get_temp_dir()
- text_file_path, zip_file_path = self._create_zip_file(orig_dir)
+ _, zip_file_path = self._create_zip_file(orig_dir)
self._test_file_extraction_and_validation(
dest_dir, zip_file_path, "zip"
)
@@ -348,7 +348,7 @@ def test_get_file_with_tgz_extension(self):
"""Test extraction of file with .tar.gz extension."""
dest_dir = self.get_temp_dir()
orig_dir = dest_dir
- text_file_path, tar_file_path = self._create_tar_file(orig_dir)
+ _, tar_file_path = self._create_tar_file(orig_dir)
origin = urllib.parse.urljoin(
"file://",
@@ -358,8 +358,8 @@ def test_get_file_with_tgz_extension(self):
path = file_utils.get_file(
"test.txt.tar.gz", origin, untar=True, cache_subdir=dest_dir
)
- self.assertTrue(path.endswith(".txt"))
self.assertTrue(os.path.exists(path))
+ self.assertTrue(os.path.exists(os.path.join(path, "test.txt")))
def test_get_file_with_integrity_check(self):
"""Test file download with integrity check."""
@@ -459,7 +459,7 @@ def _create_tar_file(self, directory):
text_file.write("Float like a butterfly, sting like a bee.")
with tarfile.open(tar_file_path, "w:gz") as tar_file:
- tar_file.add(text_file_path)
+ tar_file.add(text_file_path, arcname="test.txt")
return text_file_path, tar_file_path
@@ -471,7 +471,7 @@ def _create_zip_file(self, directory):
text_file.write("Float like a butterfly, sting like a bee.")
with zipfile.ZipFile(zip_file_path, "w") as zip_file:
- zip_file.write(text_file_path)
+ zip_file.write(text_file_path, arcname="test.txt")
return text_file_path, zip_file_path
@@ -484,7 +484,6 @@ def _test_file_extraction_and_validation(
urllib.request.pathname2url(os.path.abspath(file_path)),
)
- hashval_sha256 = file_utils.hash_file(file_path)
hashval_md5 = file_utils.hash_file(file_path, algorithm="md5")
if archive_type:
@@ -499,17 +498,15 @@ def _test_file_extraction_and_validation(
extract=extract,
cache_subdir=dest_dir,
)
- path = file_utils.get_file(
- "test",
- origin,
- file_hash=hashval_sha256,
- extract=extract,
- cache_subdir=dest_dir,
- )
+ if extract:
+ fpath = path + "_archive"
+ else:
+ fpath = path
+
self.assertTrue(os.path.exists(path))
- self.assertTrue(file_utils.validate_file(path, hashval_sha256))
- self.assertTrue(file_utils.validate_file(path, hashval_md5))
- os.remove(path)
+ self.assertTrue(file_utils.validate_file(fpath, hashval_md5))
+ if extract:
+ self.assertTrue(os.path.exists(os.path.join(path, "test.txt")))
def test_exists(self):
temp_dir = self.get_temp_dir()
@@ -721,6 +718,9 @@ def test_cns_remote_path(self):
def test_cfs_remote_path(self):
self.assertTrue(file_utils.is_remote_path("/cfs/some/path"))
+ def test_readahead_remote_path(self):
+ self.assertTrue(file_utils.is_remote_path("/readahead/some/path"))
+
def test_non_remote_paths(self):
self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt"))
self.assertFalse(
diff --git a/keras/src/utils/image_dataset_utils.py b/keras/src/utils/image_dataset_utils.py
index 380b4337973f..c1918be73eef 100755
--- a/keras/src/utils/image_dataset_utils.py
+++ b/keras/src/utils/image_dataset_utils.py
@@ -83,15 +83,15 @@ def image_dataset_from_directory(
(must match names of subdirectories). Used to control the order
of the classes (otherwise alphanumerical order is used).
color_mode: One of `"grayscale"`, `"rgb"`, `"rgba"`.
- Defaults to `"rgb"`. Whether the images will be converted to
- have 1, 3, or 4 channels.
+ Whether the images will be converted to
+ have 1, 3, or 4 channels. Defaults to `"rgb"`.
batch_size: Size of the batches of data. Defaults to 32.
If `None`, the data will not be batched
(the dataset will yield individual samples).
image_size: Size to resize images to after they are read from disk,
- specified as `(height, width)`. Defaults to `(256, 256)`.
+ specified as `(height, width)`.
Since the pipeline processes batches of images that must all have
- the same size, this must be provided.
+ the same size, this must be provided. Defaults to `(256, 256)`.
shuffle: Whether to shuffle the data. Defaults to `True`.
If set to `False`, sorts the data in alphanumeric order.
seed: Optional random seed for shuffling and transformations.
@@ -103,9 +103,10 @@ def image_dataset_from_directory(
When `subset="both"`, the utility returns a tuple of two datasets
(the training and validation datasets respectively).
interpolation: String, the interpolation method used when
- resizing images. Defaults to `"bilinear"`.
+ resizing images.
Supports `"bilinear"`, `"nearest"`, `"bicubic"`, `"area"`,
`"lanczos3"`, `"lanczos5"`, `"gaussian"`, `"mitchellcubic"`.
+ Defaults to `"bilinear"`.
follow_links: Whether to visit subdirectories pointed to by symlinks.
Defaults to `False`.
crop_to_aspect_ratio: If `True`, resize the images without aspect
@@ -196,6 +197,14 @@ def image_dataset_from_directory(
f"Received: color_mode={color_mode}"
)
+ if isinstance(image_size, int):
+ image_size = (image_size, image_size)
+ elif not isinstance(image_size, (list, tuple)) or not len(image_size) == 2:
+ raise ValueError(
+ "Invalid `image_size` value. Expected a tuple of 2 integers. "
+ f"Received: image_size={image_size}"
+ )
+
interpolation = interpolation.lower()
supported_interpolations = (
"bilinear",
diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py
index 8f5e805c5f75..ca8289c9f9b7 100644
--- a/keras/src/utils/image_utils.py
+++ b/keras/src/utils/image_utils.py
@@ -350,9 +350,9 @@ def smart_resize(
or `(batch_size, height, width, channels)`.
size: Tuple of `(height, width)` integer. Target size.
interpolation: String, interpolation to use for resizing.
- Defaults to `'bilinear'`.
- Supports `bilinear`, `nearest`, `bicubic`,
- `lanczos3`, `lanczos5`.
+ Supports `"bilinear"`, `"nearest"`, `"bicubic"`,
+ `"lanczos3"`, `"lanczos5"`.
+ Defaults to `"bilinear"`.
data_format: `"channels_last"` or `"channels_first"`.
backend_module: Backend module to use (if different from the default
backend).
@@ -388,9 +388,9 @@ def smart_resize(
if isinstance(height, int) and isinstance(width, int):
# For JAX, we need to keep the slice indices as static integers
crop_height = int(float(width * target_height) / target_width)
- crop_height = min(height, crop_height)
+ crop_height = max(min(height, crop_height), 1)
crop_width = int(float(height * target_width) / target_height)
- crop_width = min(width, crop_width)
+ crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)
else:
@@ -400,13 +400,16 @@ def smart_resize(
"int32",
)
crop_height = backend_module.numpy.minimum(height, crop_height)
+ crop_height = backend_module.numpy.maximum(crop_height, 1)
crop_height = backend_module.cast(crop_height, "int32")
+
crop_width = backend_module.cast(
backend_module.cast(height * target_width, "float32")
/ target_height,
"int32",
)
crop_width = backend_module.numpy.minimum(width, crop_width)
+ crop_width = backend_module.numpy.maximum(crop_width, 1)
crop_width = backend_module.cast(crop_width, "int32")
crop_box_hstart = backend_module.cast(
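The new `max(..., 1)` clamps keep `smart_resize` from producing a zero-sized crop box when the input and target aspect ratios differ wildly. A worked example of the static-integer path:

```python
# Illustrative values: a 1 x 100 input resized toward a 100 x 1 target.
height, width = 1, 100
target_height, target_width = 100, 1

crop_height = int(float(width * target_height) / target_width)  # 10000
crop_height = max(min(height, crop_height), 1)                  # 1
crop_width = int(float(height * target_width) / target_height)  # 0 before the clamp
crop_width = max(min(width, crop_width), 1)                     # 1 thanks to the clamp

assert crop_height == 1 and crop_width == 1
```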
diff --git a/keras/src/utils/io_utils.py b/keras/src/utils/io_utils.py
index 32322f405c33..f087ab6dd21a 100644
--- a/keras/src/utils/io_utils.py
+++ b/keras/src/utils/io_utils.py
@@ -91,10 +91,18 @@ def set_logging_verbosity(level):
def print_msg(message, line_break=True):
"""Print the message to absl logging or stdout."""
+ message = str(message)
if is_interactive_logging_enabled():
- if line_break:
- sys.stdout.write(message + "\n")
- else:
+ message = message + "\n" if line_break else message
+ try:
+ sys.stdout.write(message)
+ except UnicodeEncodeError:
+ # If the encoding differs from UTF-8, `sys.stdout.write` may fail.
+ # To address this, replace special unicode characters in the
+ # message, and then encode and decode using the target encoding.
+ message = _replace_special_unicode_character(message)
+ message_bytes = message.encode(sys.stdout.encoding, errors="ignore")
+ message = message_bytes.decode(sys.stdout.encoding)
sys.stdout.write(message)
sys.stdout.flush()
else:
@@ -123,3 +131,8 @@ def ask_to_proceed_with_overwrite(filepath):
return False
print_msg("[TIP] Next time specify overwrite=True!")
return True
+
+
+def _replace_special_unicode_character(message):
+ message = str(message).replace("━", "=") # Fall back to Keras2 behavior.
+ return message
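The new `print_msg` branch only triggers when `sys.stdout` uses an encoding that cannot represent the `━` progress-bar character (for example `cp1251`, per the linked issue). A minimal sketch of the same fallback outside Keras:

```python
import sys

message = "\u2501" * 10 + " done"  # "━━━━━━━━━━ done"
try:
    sys.stdout.write(message + "\n")
except UnicodeEncodeError:
    # Same strategy as above: replace the special character, then round-trip
    # through the target encoding, dropping anything it still can't encode.
    message = message.replace("\u2501", "=")
    message_bytes = message.encode(sys.stdout.encoding, errors="ignore")
    sys.stdout.write(message_bytes.decode(sys.stdout.encoding) + "\n")
```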
diff --git a/keras/src/utils/io_utils_test.py b/keras/src/utils/io_utils_test.py
index 235314de3016..2fe1fbbea219 100644
--- a/keras/src/utils/io_utils_test.py
+++ b/keras/src/utils/io_utils_test.py
@@ -1,3 +1,5 @@
+import sys
+import tempfile
from unittest.mock import patch
from keras.src.testing import test_case
@@ -55,3 +57,13 @@ def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _):
@patch("builtins.input", side_effect=["invalid", "n"])
def test_ask_to_proceed_with_overwrite_invalid_then_no(self, _):
self.assertFalse(io_utils.ask_to_proceed_with_overwrite("test_path"))
+
+ def test_print_msg_with_different_encoding(self):
+ # https://github.com/keras-team/keras/issues/19386
+ io_utils.enable_interactive_logging()
+ self.assertTrue(io_utils.is_interactive_logging_enabled())
+ ori_stdout = sys.stdout
+ with tempfile.TemporaryFile(mode="w", encoding="cp1251") as tmp:
+ sys.stdout = tmp
+ io_utils.print_msg("━")
+ sys.stdout = ori_stdout
diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py
index 9c97f0ac28d4..7776e7a5ba2a 100644
--- a/keras/src/utils/jax_layer.py
+++ b/keras/src/utils/jax_layer.py
@@ -5,6 +5,8 @@
from keras.src import backend
from keras.src import tree
from keras.src.api_export import keras_export
+from keras.src.backend.common.variables import is_float_dtype
+from keras.src.backend.common.variables import standardize_dtype
from keras.src.layers.layer import Layer
from keras.src.saving import serialization_lib
from keras.src.utils import jax_utils
@@ -192,7 +194,7 @@ def my_haiku_module_fn(inputs, training):
call_fn: The function to call the model. See description above for the
list of arguments it takes and the outputs it returns.
init_fn: the function to call to initialize the model. See description
- above for the list of arguments it takes and the ouputs it returns.
+ above for the list of arguments it takes and the outputs it returns.
If `None`, then `params` and/or `state` must be provided.
params: A `PyTree` containing all the model trainable parameters. This
allows passing trained parameters or controlling the initialization.
@@ -204,6 +206,8 @@ def my_haiku_module_fn(inputs, training):
argument, then `init_fn` is called at build time to initialize the
non-trainable state of the model.
seed: Seed for random number generator. Optional.
+ dtype: The dtype of the layer's computations and weights. Can also be a
+ `keras.DTypePolicy`. Optional. Defaults to the default policy.
"""
def __init__(
@@ -291,18 +295,28 @@ def _create_variables(self, values, trainable):
"""
def create_variable(value):
- if backend.is_tensor(value) or isinstance(value, np.ndarray):
- variable = self.add_weight(
- value.shape, initializer="zeros", trainable=trainable
+ if backend.is_tensor(value) or isinstance(
+ value, (np.ndarray, np.generic)
+ ):
+ dtype = value.dtype
+ if is_float_dtype(dtype):
+ dtype = None # Use the layer dtype policy
+ return self.add_weight(
+ value.shape,
+ initializer=value,
+ dtype=dtype,
+ trainable=trainable,
)
- variable.assign(value)
- return variable
- elif isinstance(value, (np.generic, int, float)):
- variable = self.add_weight(
- (), initializer="zeros", trainable=trainable
+ elif isinstance(value, (bool, int, float)):
+ dtype = standardize_dtype(type(value))
+ if is_float_dtype(dtype):
+ dtype = None # Use the layer dtype policy
+ return self.add_weight(
+ (),
+ initializer=backend.convert_to_tensor(value),
+ dtype=dtype,
+ trainable=trainable,
)
- variable.assign(value)
- return variable
else:
return value
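The dtype handling above boils down to: floating-point values defer to the layer's dtype policy (`dtype=None`), everything else keeps its own dtype, and the value itself becomes the initializer. A NumPy-only sketch of that selection rule:

```python
import numpy as np

# Illustrative helper, not the Keras implementation.
def pick_variable_dtype(value):
    dtype = np.asarray(value).dtype
    # Float dtypes return None so the layer's dtype policy decides.
    return None if np.issubdtype(dtype, np.floating) else str(dtype)

assert pick_variable_dtype(np.zeros(3, dtype="float32")) is None
assert pick_variable_dtype(np.zeros(3, dtype="int32")) == "int32"
assert pick_variable_dtype(True) == "bool"
```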
diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py
index e3b088c78849..96c74809d13d 100644
--- a/keras/src/utils/jax_layer_test.py
+++ b/keras/src/utils/jax_layer_test.py
@@ -15,7 +15,7 @@
from keras.src import testing
from keras.src import tree
from keras.src import utils
-from keras.src.export import export_lib
+from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.saving import object_registration
from keras.src.utils.jax_layer import FlaxLayer
from keras.src.utils.jax_layer import JaxLayer
@@ -182,7 +182,7 @@ def from_config(cls, config):
backend.backend() != "jax",
reason="JaxLayer and FlaxLayer are only supported with JAX backend",
)
-class TestJaxLayer(testing.TestCase, parameterized.TestCase):
+class TestJaxLayer(testing.TestCase):
def _test_layer(
self,
model_name,
@@ -207,7 +207,6 @@ def _count_params(weights):
return count
def verify_weights_and_params(layer):
-
self.assertEqual(trainable_weights, len(layer.trainable_weights))
self.assertEqual(
trainable_params,
@@ -322,14 +321,20 @@ def verify_identical_model(model):
# export, load back and compare results
path = os.path.join(self.get_temp_dir(), "jax_layer_export")
- export_lib.export_model(model2, path)
+ model2.export(path, format="tf_saved_model")
model4 = tf.saved_model.load(path)
output4 = model4.serve(x_test)
- self.assertAllClose(output1, output4)
+ # The output difference is greater when using the GPU or bfloat16
+ lower_precision = testing.jax_uses_gpu() or "dtype" in layer_init_kwargs
+ self.assertAllClose(
+ output1,
+ output4,
+ atol=1e-2 if lower_precision else 1e-6,
+ rtol=1e-3 if lower_precision else 1e-6,
+ )
# test subclass model building without a build method
class TestModel(models.Model):
-
def __init__(self, layer):
super().__init__()
self._layer = layer
@@ -365,6 +370,18 @@ def call(self, inputs):
"non_trainable_weights": 1,
"non_trainable_params": 1,
},
+ {
+ "testcase_name": "training_state_dtype_policy",
+ "init_kwargs": {
+ "call_fn": jax_stateful_apply,
+ "init_fn": jax_stateful_init,
+ "dtype": DTypePolicy("mixed_float16"),
+ },
+ "trainable_weights": 6,
+ "trainable_params": 266610,
+ "non_trainable_weights": 1,
+ "non_trainable_params": 1,
+ },
)
def test_jax_layer(
self,
@@ -417,6 +434,19 @@ def test_jax_layer(
"non_trainable_weights": 8,
"non_trainable_params": 536,
},
+ {
+ "testcase_name": "training_rng_unbound_method_dtype_policy",
+ "flax_model_class": "FlaxDropoutModel",
+ "flax_model_method": None,
+ "init_kwargs": {
+ "method": "flax_dropout_wrapper",
+ "dtype": DTypePolicy("mixed_float16"),
+ },
+ "trainable_weights": 8,
+ "trainable_params": 648226,
+ "non_trainable_weights": 0,
+ "non_trainable_params": 0,
+ },
)
@pytest.mark.skipif(flax is None, reason="Flax library is not available.")
def test_flax_layer(
diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py
index 2ac944eb967d..d5375785f762 100644
--- a/keras/src/utils/jax_utils.py
+++ b/keras/src/utils/jax_utils.py
@@ -5,6 +5,7 @@ def is_in_jax_tracing_scope(x=None):
if backend.backend() == "jax":
if x is None:
x = backend.numpy.ones(())
- if x.__class__.__name__ == "DynamicJaxprTracer":
- return True
+ for c in x.__class__.__mro__:
+ if c.__name__ == "Tracer" and c.__module__.startswith("jax"):
+ return True
return False
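Why the MRO walk is more robust than matching `DynamicJaxprTracer` by name: the concrete tracer class can change across JAX versions and transformations, but a class named `Tracer` from a `jax` module is always in its MRO. A usage illustration (assumes JAX is installed):

```python
import jax
import jax.numpy as jnp

def looks_like_tracer(x):
    return any(
        c.__name__ == "Tracer" and c.__module__.startswith("jax")
        for c in x.__class__.__mro__
    )

@jax.jit
def fn(x):
    assert looks_like_tracer(x)  # True while tracing under jit
    return x + 1

fn(jnp.ones(()))
assert not looks_like_tracer(jnp.ones(()))  # concrete arrays are not tracers
```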
diff --git a/keras/src/utils/model_visualization.py b/keras/src/utils/model_visualization.py
index 1fd180339b14..1fd539961ba6 100644
--- a/keras/src/utils/model_visualization.py
+++ b/keras/src/utils/model_visualization.py
@@ -8,16 +8,15 @@
from keras.src.utils import io_utils
try:
- # pydot-ng is a fork of pydot that is better maintained.
- import pydot_ng as pydot
+ import pydot
except ImportError:
- # pydotplus is an improved version of pydot
+ # pydot_ng and pydotplus are older forks of pydot
+ # which may still be used by some users
try:
- import pydotplus as pydot
+ import pydot_ng as pydot
except ImportError:
- # Fall back on pydot if necessary.
try:
- import pydot
+ import pydotplus as pydot
except ImportError:
pydot = None
@@ -36,7 +35,7 @@ def check_graphviz():
# to check the pydot/graphviz installation.
pydot.Dot.create(pydot.Dot())
return True
- except (OSError, pydot.InvocationException):
+ except (OSError, pydot.PydotException):
return False
@@ -150,7 +149,7 @@ def format_shape(shape):
cols.append(
(
''
- f'Output dtype: {dtype or "?"}'
+ f"Output dtype: {dtype or '?'}"
" | "
)
)
@@ -190,6 +189,14 @@ def make_node(layer, **kwargs):
return node
+def remove_unused_edges(dot):
+ nodes = [v.get_name() for v in dot.get_nodes()]
+ for edge in dot.get_edges():
+ if edge.get_destination() not in nodes:
+ dot.del_edge(edge.get_source(), edge.get_destination())
+ return dot
+
+
@keras_export("keras.utils.model_to_dot")
def model_to_dot(
model,
@@ -460,6 +467,7 @@ def plot_model(
to_file = str(to_file)
if dot is None:
return
+ dot = remove_unused_edges(dot)
_, extension = os.path.splitext(to_file)
if not extension:
extension = "png"
diff --git a/keras/src/utils/module_utils.py b/keras/src/utils/module_utils.py
index c0991fd6bed3..190bc8dc72fe 100644
--- a/keras/src/utils/module_utils.py
+++ b/keras/src/utils/module_utils.py
@@ -2,10 +2,13 @@
class LazyModule:
- def __init__(self, name, pip_name=None):
+ def __init__(self, name, pip_name=None, import_error_msg=None):
self.name = name
- pip_name = pip_name or name
- self.pip_name = pip_name
+ self.pip_name = pip_name or name
+ self.import_error_msg = import_error_msg or (
+ f"This requires the {self.name} module. "
+ f"You can install it via `pip install {self.pip_name}`"
+ )
self.module = None
self._available = None
@@ -23,10 +26,7 @@ def initialize(self):
try:
self.module = importlib.import_module(self.name)
except ImportError:
- raise ImportError(
- f"This requires the {self.name} module. "
- f"You can install it via `pip install {self.pip_name}`"
- )
+ raise ImportError(self.import_error_msg)
def __getattr__(self, name):
if name == "_api_export_path":
@@ -35,11 +35,27 @@ def __getattr__(self, name):
self.initialize()
return getattr(self.module, name)
+ def __repr__(self):
+ return f"LazyModule({self.name})"
+
tensorflow = LazyModule("tensorflow")
gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
tensorflow_io = LazyModule("tensorflow_io")
scipy = LazyModule("scipy")
jax = LazyModule("jax")
+torchvision = LazyModule("torchvision")
+torch_xla = LazyModule(
+ "torch_xla",
+ import_error_msg=(
+ "This requires the torch_xla module. You can install it via "
+ "`pip install torch-xla`. Additionally, you may need to update "
+ "LD_LIBRARY_PATH if necessary. Torch XLA builds a shared library, "
+ "_XLAC.so, which needs to link to the version of Python it was built "
+ "with. Use the following command to update LD_LIBRARY_PATH: "
+ "`export LD_LIBRARY_PATH=/lib:$LD_LIBRARY_PATH`"
+ ),
+)
optree = LazyModule("optree")
dmtree = LazyModule("tree")
+tf2onnx = LazyModule("tf2onnx")
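The `import_error_msg` hook lets heavy optional dependencies such as `torch_xla` carry installation-specific guidance without importing anything at module load time. A standalone sketch of the lazy-import pattern (with an illustrative `LazyThing` class, not the Keras one):

```python
import importlib

class LazyThing:
    def __init__(self, name, import_error_msg=None):
        self.name = name
        self.import_error_msg = import_error_msg or (
            f"This requires the {name} module."
        )
        self.module = None

    def __getattr__(self, attr):
        # Only reached for attributes not set in __init__, i.e. real lookups.
        if self.module is None:
            try:
                self.module = importlib.import_module(self.name)
            except ImportError:
                raise ImportError(self.import_error_msg)
        return getattr(self.module, attr)

math_module = LazyThing("math")
assert math_module.sqrt(4) == 2.0  # the import happens on first attribute access
```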
diff --git a/keras/src/utils/numerical_utils.py b/keras/src/utils/numerical_utils.py
index 05fb82abc522..dcd2cc422d6a 100644
--- a/keras/src/utils/numerical_utils.py
+++ b/keras/src/utils/numerical_utils.py
@@ -2,6 +2,7 @@
from keras.src import backend
from keras.src.api_export import keras_export
+from keras.src.utils import tf_utils
@keras_export("keras.utils.normalize")
@@ -73,6 +74,15 @@ def to_categorical(x, num_classes=None):
[0. 0. 0. 0.]
"""
if backend.is_tensor(x):
+ input_shape = backend.core.shape(x)
+ # Shrink the last dimension if the shape is (..., 1).
+ if (
+ input_shape is not None
+ and len(input_shape) > 1
+ and input_shape[-1] == 1
+ ):
+ newshape = tuple(input_shape[:-1])
+ x = backend.numpy.reshape(x, newshape)
return backend.nn.one_hot(x, num_classes)
x = np.array(x, dtype="int64")
input_shape = x.shape
@@ -96,48 +106,120 @@ def encode_categorical_inputs(
inputs,
output_mode,
depth,
- dtype="float32",
+ dtype,
+ sparse=False,
+ count_weights=None,
backend_module=None,
):
- """Encodes categorical inputs according to output_mode."""
+ """Encodes categorical inputs according to output_mode.
+
+ Args:
+ inputs: the inputs to encode.
+ output_mode: one of `"int"`, `"one_hot"`, `"multi_hot"`, or `"count"`.
+ depth: number of classes; this will be the last dimension of the output.
+ dtype: the dtype of the output, unless `count_weights` is provided,
+ in which case the dtype of `count_weights` is used.
+ sparse: whether the output should be sparse for backends supporting it.
+ count_weights: weights to apply if `output_mode` is `"count"`.
+ backend_module: the backend to use instead of the current one.
+
+ Returns: the encoded inputs.
+ """
backend_module = backend_module or backend
if output_mode == "int":
return backend_module.cast(inputs, dtype=dtype)
- binary_output = output_mode in ("multi_hot", "one_hot")
- original_shape = backend_module.shape(inputs)
- rank_of_inputs = len(original_shape)
+ rank_of_inputs = len(backend_module.shape(inputs))
# In all cases, we should uprank scalar input to a single sample.
if rank_of_inputs == 0:
- # We need to update `rank_of_inputs`
- # If necessary.
inputs = backend_module.numpy.expand_dims(inputs, -1)
- elif rank_of_inputs > 2:
- # The `count` mode does not support inputs with a rank greater than 2.
- if not binary_output:
- raise ValueError(
- "When output_mode is anything other than "
- "`'multi_hot', 'one_hot', or 'int'`, "
- "the rank must be 2 or less. "
- f"Received output_mode: {output_mode} "
- f"and input shape: {original_shape}, "
- f"which would result in output rank {rank_of_inputs}."
+ rank_of_inputs = 1
+
+ if (
+ backend_module.__name__.endswith("tensorflow")
+ and rank_of_inputs <= 2
+ and output_mode in ("multi_hot", "count")
+ ):
+ # TF only fastpath. Uses bincount; faster. Doesn't work for rank 3+.
+ try:
+ return tf_utils.tf_encode_categorical_inputs(
+ inputs,
+ output_mode,
+ depth,
+ dtype=dtype,
+ sparse=sparse,
+ count_weights=count_weights,
)
+ except ValueError:
+ pass
- if binary_output:
- if output_mode == "one_hot":
- bincounts = backend_module.nn.one_hot(inputs, depth)
- elif output_mode == "multi_hot":
- one_hot_input = backend_module.nn.one_hot(inputs, depth)
- bincounts = backend_module.numpy.where(
- backend_module.numpy.any(one_hot_input, axis=-2), 1, 0
- )
- else:
- bincounts = backend_module.numpy.bincount(
- inputs,
- minlength=depth,
+ if output_mode == "multi_hot":
+ return backend_module.nn.multi_hot(
+ inputs, depth, dtype=dtype, sparse=sparse
+ )
+ elif output_mode == "one_hot":
+ input_shape = backend_module.core.shape(inputs)
+ # Shrink the last dimension if the shape is (..., 1).
+ if (
+ input_shape is not None
+ and len(input_shape) > 1
+ and input_shape[-1] == 1
+ ):
+ newshape = tuple(input_shape[:-1])
+ inputs = backend_module.numpy.reshape(inputs, newshape)
+ return backend_module.nn.one_hot(
+ inputs, depth, dtype=dtype, sparse=sparse
+ )
+ elif output_mode == "count":
+ # We don't use `ops.bincount` because its output has a dynamic shape
+ # (last dimension is the highest value of `inputs`). We implement a
+ # narrower use case where `minlength` and `maxlength` (not supported by
+ # `ops.bincount`) are the same and static value: `depth`. We also don't
+ # need to support indices that are negative or greater than `depth`.
+ reduction_axis = 1 if len(inputs.shape) > 1 else 0
+
+ if count_weights is not None:
+ dtype = count_weights.dtype
+ one_hot_encoding = backend_module.nn.one_hot(
+ inputs, depth, dtype=dtype, sparse=sparse
+ )
+ if count_weights is not None:
+ count_weights = backend_module.numpy.expand_dims(count_weights, -1)
+ one_hot_encoding = one_hot_encoding * count_weights
+
+ outputs = backend_module.numpy.sum(
+ one_hot_encoding,
+ axis=reduction_axis,
+ )
+ return outputs
+
+
+def build_pos_neg_masks(
+ query_labels,
+ key_labels,
+ remove_diagonal=True,
+):
+ from keras.src import ops
+
+ if ops.ndim(query_labels) == 1:
+ query_labels = ops.reshape(query_labels, (-1, 1))
+
+ if ops.ndim(key_labels) == 1:
+ key_labels = ops.reshape(key_labels, (-1, 1))
+
+ positive_mask = ops.equal(query_labels, ops.transpose(key_labels))
+ negative_mask = ops.logical_not(positive_mask)
+
+ if remove_diagonal:
+ positive_mask = ops.logical_and(
+ positive_mask,
+ ~ops.eye(
+ ops.size(query_labels),
+ ops.size(key_labels),
+ k=0,
+ dtype="bool",
+ ),
)
- bincounts = backend_module.cast(bincounts, dtype)
- return bincounts
+
+ return positive_mask, negative_mask
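The comment block explains why the `count` mode is built from one-hot plus a sum rather than `ops.bincount`: the output needs a static last dimension of exactly `depth`. A NumPy sketch of that computation with per-index weights (values are illustrative):

```python
import numpy as np

inputs = np.array([[1, 2, 2], [0, 0, 3]])  # (batch, sequence) of class indices
depth = 4
count_weights = np.array([[1.0, 0.5, 0.5], [2.0, 1.0, 1.0]])

one_hot = np.eye(depth)[inputs]               # (batch, sequence, depth)
one_hot = one_hot * count_weights[..., None]  # weight each occurrence
counts = one_hot.sum(axis=1)                  # (batch, depth), last dim is static

assert np.allclose(counts, [[0.0, 1.0, 1.0, 0.0], [3.0, 0.0, 0.0, 1.0]])
```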
diff --git a/keras/src/utils/numerical_utils_test.py b/keras/src/utils/numerical_utils_test.py
index 2cb8c4c5e782..9b9520abc90e 100644
--- a/keras/src/utils/numerical_utils_test.py
+++ b/keras/src/utils/numerical_utils_test.py
@@ -8,7 +8,7 @@
NUM_CLASSES = 5
-class TestNumericalUtils(testing.TestCase, parameterized.TestCase):
+class TestNumericalUtils(testing.TestCase):
@parameterized.parameters(
[
((1,), (1, NUM_CLASSES)),
@@ -31,7 +31,7 @@ def test_to_categorical(self, shape, expected_shape):
np.all(np.argmax(one_hot, -1).reshape(label.shape) == label)
)
- def test_to_categorial_without_num_classes(self):
+ def test_to_categorical_without_num_classes(self):
label = [0, 2, 5]
one_hot = numerical_utils.to_categorical(label)
self.assertEqual(one_hot.shape, (3, 5 + 1))
@@ -72,3 +72,80 @@ def test_normalize(self, order):
out = numerical_utils.normalize(xb, axis=-1, order=order)
self.assertTrue(backend.is_tensor(out))
self.assertAllClose(backend.convert_to_numpy(out), expected)
+
+ def test_build_pos_neg_masks(self):
+ query_labels = np.array([0, 1, 2, 2, 0])
+ key_labels = np.array([0, 1, 2, 0, 2])
+ expected_shape = (len(query_labels), len(key_labels))
+
+ positive_mask, negative_mask = numerical_utils.build_pos_neg_masks(
+ query_labels, key_labels, remove_diagonal=False
+ )
+
+ positive_mask = backend.convert_to_numpy(positive_mask)
+ negative_mask = backend.convert_to_numpy(negative_mask)
+ self.assertEqual(positive_mask.shape, expected_shape)
+ self.assertEqual(negative_mask.shape, expected_shape)
+ self.assertTrue(
+ np.all(np.logical_not(np.logical_and(positive_mask, negative_mask)))
+ )
+
+ expected_positive_mask_keep_diag = np.array(
+ [
+ [1, 0, 0, 1, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 1, 0],
+ ],
+ dtype="bool",
+ )
+ self.assertTrue(
+ np.all(positive_mask == expected_positive_mask_keep_diag)
+ )
+ self.assertTrue(
+ np.all(
+ negative_mask
+ == np.logical_not(expected_positive_mask_keep_diag)
+ )
+ )
+
+ positive_mask, negative_mask = numerical_utils.build_pos_neg_masks(
+ query_labels, key_labels, remove_diagonal=True
+ )
+ positive_mask = backend.convert_to_numpy(positive_mask)
+ negative_mask = backend.convert_to_numpy(negative_mask)
+ self.assertEqual(positive_mask.shape, expected_shape)
+ self.assertEqual(negative_mask.shape, expected_shape)
+ self.assertTrue(
+ np.all(np.logical_not(np.logical_and(positive_mask, negative_mask)))
+ )
+
+ expected_positive_mask_with_remove_diag = np.array(
+ [
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 1, 0],
+ ],
+ dtype="bool",
+ )
+ self.assertTrue(
+ np.all(positive_mask == expected_positive_mask_with_remove_diag)
+ )
+
+ query_labels = np.array([1, 2, 3])
+ key_labels = np.array([1, 2, 3, 1])
+
+ positive_mask, negative_mask = numerical_utils.build_pos_neg_masks(
+ query_labels, key_labels, remove_diagonal=True
+ )
+ positive_mask = backend.convert_to_numpy(positive_mask)
+ negative_mask = backend.convert_to_numpy(negative_mask)
+ expected_shape_diff_sizes = (len(query_labels), len(key_labels))
+ self.assertEqual(positive_mask.shape, expected_shape_diff_sizes)
+ self.assertEqual(negative_mask.shape, expected_shape_diff_sizes)
+ self.assertTrue(
+ np.all(np.logical_not(np.logical_and(positive_mask, negative_mask)))
+ )
diff --git a/keras/src/utils/python_utils.py b/keras/src/utils/python_utils.py
index d1146b4818b4..312871675b80 100644
--- a/keras/src/utils/python_utils.py
+++ b/keras/src/utils/python_utils.py
@@ -148,3 +148,30 @@ def remove_by_id(lst, value):
if id(v) == id(value):
del lst[i]
return
+
+
+def pythonify_logs(logs):
+ """Flatten and convert log values to Python-native types.
+
+ This function attempts to convert each dict value via `float(value)` and
+ skips the conversion if it fails.
+
+ Args:
+ logs: A dict containing log values.
+
+ Returns:
+ A flattened dict with values converted to Python-native types if
+ possible.
+ """
+ logs = logs or {}
+ result = {}
+ for key, value in sorted(logs.items()):
+ if isinstance(value, dict):
+ result.update(pythonify_logs(value))
+ else:
+ try:
+ value = float(value)
+ except (TypeError, ValueError):
+ pass
+ result[key] = value
+ return result
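
For illustration (hypothetical values, assuming `pythonify_logs` is in scope):

```python
import numpy as np

logs = {"loss": np.float32(0.25), "metrics": {"acc": np.float64(0.9)}, "run": "exp_1"}
pythonify_logs(logs)
# -> {"loss": 0.25, "acc": 0.9, "run": "exp_1"}
# Scalar values become Python floats, nested dicts are flattened, and values
# that cannot be converted (here the string) are kept as-is.
```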
diff --git a/keras/src/utils/sequence_utils.py b/keras/src/utils/sequence_utils.py
index 8d0c67a08e08..cfb27ef25de6 100644
--- a/keras/src/utils/sequence_utils.py
+++ b/keras/src/utils/sequence_utils.py
@@ -69,7 +69,7 @@ def pad_sequences(
truncating: String, "pre" or "post" (optional, defaults to `"pre"`):
remove values from sequences larger than
`maxlen`, either at the beginning or at the end of the sequences.
- value: Float or String, padding value. (Optional, defaults to 0.)
+ value: Float or String, padding value. (Optional, defaults to `0.`)
Returns:
NumPy array with shape `(len(sequences), maxlen)`
@@ -101,9 +101,9 @@ def pad_sequences(
maxlen = np.max(lengths)
is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(
- dtype, np.unicode_
+ dtype, np.str_
)
- if isinstance(value, str) and dtype != object and not is_dtype_str:
+ if isinstance(value, str) and dtype is not object and not is_dtype_str:
raise ValueError(
f"`dtype` {dtype} is not compatible with `value`'s type: "
f"{type(value)}\nYou should set `dtype=object` for variable length "
diff --git a/keras/src/utils/summary_utils.py b/keras/src/utils/summary_utils.py
index 94c82af7ff84..2e67b5c2f841 100644
--- a/keras/src/utils/summary_utils.py
+++ b/keras/src/utils/summary_utils.py
@@ -1,3 +1,4 @@
+import functools
import math
import re
import shutil
@@ -21,6 +22,14 @@ def count_params(weights):
return int(sum(math.prod(p) for p in shapes))
+@functools.lru_cache(512)
+def _compute_memory_size(shape, dtype):
+ weight_counts = math.prod(shape)
+ dtype = backend.standardize_dtype(dtype)
+ per_param_size = dtype_utils.dtype_size(dtype)
+ return weight_counts * per_param_size
+
+
def weight_memory_size(weights):
"""Compute the memory footprint for weights based on their dtypes.
@@ -33,10 +42,7 @@ def weight_memory_size(weights):
unique_weights = {id(w): w for w in weights}.values()
total_memory_size = 0
for w in unique_weights:
- weight_shape = math.prod(w.shape)
- dtype = backend.standardize_dtype(w.dtype)
- per_param_size = dtype_utils.dtype_size(dtype)
- total_memory_size += weight_shape * per_param_size
+ total_memory_size += _compute_memory_size(w.shape, w.dtype)
return total_memory_size / 8
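
As a quick sanity check of the arithmetic (assuming `dtype_utils.dtype_size` reports sizes in bits, which is why the total is divided by 8):

```python
# A float32 kernel of shape (128, 64):
bits = _compute_memory_size((128, 64), "float32")  # 128 * 64 * 32 == 262144
size_in_bytes = bits / 8                           # 32768.0
```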
@@ -76,18 +82,35 @@ def bold_text(x, color=None):
def format_layer_shape(layer):
- if not layer._inbound_nodes:
+ if not layer._inbound_nodes and not layer._build_shapes_dict:
return "?"
def format_shape(shape):
highlighted = [highlight_number(x) for x in shape]
return "(" + ", ".join(highlighted) + ")"
- for i in range(len(layer._inbound_nodes)):
- outputs = layer._inbound_nodes[i].output_tensors
- output_shapes = tree.map_structure(
- lambda x: format_shape(x.shape), outputs
- )
+ # There are 2 approaches to get output shapes:
+ # 1. Using `layer._inbound_nodes`, which is possible if the model is a
+ # Sequential or Functional.
+ # 2. Using `layer._build_shapes_dict`, which is possible if users manually
+ # build the layer.
+ if len(layer._inbound_nodes) > 0:
+ for i in range(len(layer._inbound_nodes)):
+ outputs = layer._inbound_nodes[i].output_tensors
+ output_shapes = tree.map_structure(
+ lambda x: format_shape(x.shape), outputs
+ )
+ else:
+ try:
+ if hasattr(layer, "output_shape"):
+ output_shapes = format_shape(layer.output_shape)
+ else:
+ outputs = layer.compute_output_shape(**layer._build_shapes_dict)
+ output_shapes = tree.map_shape_structure(
+ lambda x: format_shape(x), outputs
+ )
+ except NotImplementedError:
+ return "?"
if len(output_shapes) == 1:
return output_shapes[0]
out = str(output_shapes)
@@ -261,7 +284,7 @@ def get_layer_fields(layer, prefix=""):
if not sequential_like:
fields.append(get_connections(layer))
if show_trainable:
- if layer.weights:
+ if hasattr(layer, "weights") and len(layer.weights) > 0:
fields.append(
bold_text("Y", color=34)
if layer.trainable
diff --git a/keras/src/utils/summary_utils_test.py b/keras/src/utils/summary_utils_test.py
index 7b917da4f848..bda3ed571260 100644
--- a/keras/src/utils/summary_utils_test.py
+++ b/keras/src/utils/summary_utils_test.py
@@ -4,11 +4,12 @@
from keras.src import layers
from keras.src import models
+from keras.src import ops
from keras.src import testing
from keras.src.utils import summary_utils
-class SummaryUtilsTest(testing.TestCase, parameterized.TestCase):
+class SummaryUtilsTest(testing.TestCase):
@parameterized.parameters([("adam",), (None,)])
@pytest.mark.requires_trainable_backend
def test_print_model_summary(self, optimizer):
@@ -40,3 +41,86 @@ def print_to_variable(text, line_break=False):
self.assertNotIn("Optimizer params", summary_content)
except ImportError:
pass
+
+ def test_print_model_summary_custom_build(self):
+ class MyModel(models.Model):
+ def __init__(self):
+ super().__init__()
+ self.dense1 = layers.Dense(4, activation="relu")
+ self.dense2 = layers.Dense(2, activation="softmax")
+ self.unbuilt_dense = layers.Dense(1)
+
+ def build(self, input_shape):
+ self.dense1.build(input_shape)
+ input_shape = self.dense1.compute_output_shape(input_shape)
+ self.dense2.build(input_shape)
+
+ def call(self, inputs):
+ x = self.dense1(inputs)
+ return self.dense2(x)
+
+ model = MyModel()
+ model.build((None, 2))
+
+ summary_content = []
+
+ def print_to_variable(text, line_break=False):
+ summary_content.append(text)
+
+ summary_utils.print_summary(model, print_fn=print_to_variable)
+ summary_content = "\n".join(summary_content)
+ self.assertIn("(None, 4)", summary_content) # dense1
+ self.assertIn("(None, 2)", summary_content) # dense2
+ self.assertIn("?", summary_content) # unbuilt_dense
+ self.assertIn("Total params: 22", summary_content)
+ self.assertIn("Trainable params: 22", summary_content)
+ self.assertIn("Non-trainable params: 0", summary_content)
+
+ def test_print_model_summary_op_as_layer(self):
+ inputs = layers.Input((2,))
+ x = layers.Dense(4)(inputs)
+ outputs = ops.mean(x)
+ model = models.Model(inputs, outputs)
+
+ summary_content = []
+
+ def print_to_variable(text, line_break=False):
+ summary_content.append(text)
+
+ summary_utils.print_summary(
+ model, print_fn=print_to_variable, show_trainable=True
+ )
+ summary_content = "\n".join(summary_content)
+ self.assertIn("(None, 4)", summary_content) # dense
+ self.assertIn("Y", summary_content) # dense
+ self.assertIn("()", summary_content) # mean
+ self.assertIn("-", summary_content) # mean
+ self.assertIn("Total params: 12", summary_content)
+ self.assertIn("Trainable params: 12", summary_content)
+ self.assertIn("Non-trainable params: 0", summary_content)
+
+ def test_print_model_summary_with_mha(self):
+ # In Keras <= 3.6, MHA exposes an `output_shape` property, which breaks
+ # this test.
+ class MyModel(models.Model):
+ def __init__(self):
+ super().__init__()
+ self.mha = layers.MultiHeadAttention(2, 2, output_shape=(4,))
+
+ def call(self, inputs):
+ return self.mha(inputs, inputs, inputs)
+
+ model = MyModel()
+ model(np.ones((1, 2, 2)))
+
+ summary_content = []
+
+ def print_to_variable(text, line_break=False):
+ summary_content.append(text)
+
+ summary_utils.print_summary(model, print_fn=print_to_variable)
+ summary_content = "\n".join(summary_content)
+ self.assertIn("(1, 2, 4)", summary_content) # mha
+ self.assertIn("Total params: 56", summary_content)
+ self.assertIn("Trainable params: 56", summary_content)
+ self.assertIn("Non-trainable params: 0", summary_content)
diff --git a/keras/src/utils/text_dataset_utils.py b/keras/src/utils/text_dataset_utils.py
index ab1272bf190d..a76134818570 100644
--- a/keras/src/utils/text_dataset_utils.py
+++ b/keras/src/utils/text_dataset_utils.py
@@ -72,13 +72,15 @@ def text_dataset_from_directory(
This is the explicit list of class names
(must match names of subdirectories). Used to control the order
of the classes (otherwise alphanumerical order is used).
- batch_size: Size of the batches of data. Defaults to 32.
+ batch_size: Size of the batches of data.
If `None`, the data will not be batched
(the dataset will yield individual samples).
+ Defaults to `32`.
max_length: Maximum size of a text string. Texts longer than this will
be truncated to `max_length`.
- shuffle: Whether to shuffle the data. Defaults to `True`.
+ shuffle: Whether to shuffle the data.
If set to `False`, sorts the data in alphanumeric order.
+ Defaults to `True`.
seed: Optional random seed for shuffling and transformations.
validation_split: Optional float between 0 and 1,
fraction of data to reserve for validation.
diff --git a/keras/src/utils/tf_utils.py b/keras/src/utils/tf_utils.py
index ea8e45aaf7c7..485cc2c1362c 100644
--- a/keras/src/utils/tf_utils.py
+++ b/keras/src/utils/tf_utils.py
@@ -1,12 +1,47 @@
+from keras.src import backend
from keras.src.utils.module_utils import tensorflow as tf
-def expand_dims(inputs, axis):
- """Expand dims on sparse, ragged, or dense tensors."""
- if isinstance(inputs, tf.SparseTensor):
- return tf.sparse.expand_dims(inputs, axis)
+def get_tensor_spec(t, dynamic_batch=False, name=None):
+ """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
+ if isinstance(t, tf.TypeSpec):
+ spec = t
+ elif isinstance(t, tf.__internal__.CompositeTensor):
+ # Check for ExtensionTypes
+ spec = t._type_spec
+ elif hasattr(t, "shape") and hasattr(t, "dtype"):
+ spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
else:
- return tf.expand_dims(inputs, axis)
+ return None # Allow non-Tensors to pass through.
+
+ if not dynamic_batch:
+ return spec
+
+ shape = spec.shape
+ if shape.rank is None or shape.rank == 0:
+ return spec
+
+ shape_list = shape.as_list()
+ shape_list[0] = None
+ shape = tf.TensorShape(shape_list)
+ spec._shape = shape
+ return spec
+
+
+def ensure_tensor(inputs, dtype=None):
+ """Ensures the input is a Tensor, SparseTensor or RaggedTensor."""
+ if not isinstance(inputs, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)):
+ if backend.backend() == "torch" and backend.is_tensor(inputs):
+ # Plain `np.asarray()` conversion fails with PyTorch.
+ inputs = backend.convert_to_numpy(inputs)
+ inputs = tf.convert_to_tensor(inputs, dtype)
+ if dtype is not None and inputs.dtype != dtype:
+ inputs = tf.cast(inputs, dtype)
+ return inputs
+
+
+def is_ragged_tensor(x):
+ return "ragged_tensor.RaggedTensor" in str(type(x))
def sparse_bincount(inputs, depth, binary_output, dtype, count_weights=None):
@@ -50,7 +85,14 @@ def dense_bincount(inputs, depth, binary_output, dtype, count_weights=None):
return result
-def encode_categorical_inputs(
+def expand_dims(inputs, axis):
+ """Expand dims on sparse, ragged, or dense tensors."""
+ if isinstance(inputs, tf.SparseTensor):
+ return tf.sparse.expand_dims(inputs, axis)
+ return tf.expand_dims(inputs, axis)
+
+
+def tf_encode_categorical_inputs(
inputs,
output_mode,
depth,
@@ -59,7 +101,11 @@ def encode_categorical_inputs(
count_weights=None,
idf_weights=None,
):
- """Encodes categoical inputs according to output_mode."""
+ """Encodes categorical inputs according to output_mode.
+
+ Faster method that relies on bincount.
+ """
+
if output_mode == "int":
return tf.identity(tf.cast(inputs, dtype))
@@ -72,7 +118,6 @@ def encode_categorical_inputs(
if inputs.shape[-1] != 1:
inputs = expand_dims(inputs, -1)
- # TODO(b/190445202): remove output rank restriction.
if inputs.shape.rank > 2:
raise ValueError(
"When output_mode is not `'int'`, maximum supported output rank "
@@ -91,6 +136,7 @@ def encode_categorical_inputs(
inputs, depth, binary_output, dtype, count_weights
)
+ bincounts = tf.cast(bincounts, dtype)
if output_mode != "tf_idf":
return bincounts
@@ -108,39 +154,4 @@ def encode_categorical_inputs(
bincounts.dense_shape,
)
else:
- return tf.multiply(tf.cast(bincounts, idf_weights.dtype), idf_weights)
-
-
-def get_tensor_spec(t, dynamic_batch=False, name=None):
- """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
- if isinstance(t, tf.TypeSpec):
- spec = t
- elif isinstance(t, tf.__internal__.CompositeTensor):
- # Check for ExtensionTypes
- spec = t._type_spec
- elif hasattr(t, "shape") and hasattr(t, "dtype"):
- spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
- else:
- return None # Allow non-Tensors to pass through.
-
- if not dynamic_batch:
- return spec
-
- shape = spec.shape
- if shape.rank is None or shape.rank == 0:
- return spec
-
- shape_list = shape.as_list()
- shape_list[0] = None
- shape = tf.TensorShape(shape_list)
- spec._shape = shape
- return spec
-
-
-def ensure_tensor(inputs, dtype=None):
- """Ensures the input is a Tensor, SparseTensor or RaggedTensor."""
- if not isinstance(inputs, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)):
- inputs = tf.convert_to_tensor(inputs, dtype)
- if dtype is not None and inputs.dtype != dtype:
- inputs = tf.cast(inputs, dtype)
- return inputs
+ return tf.multiply(bincounts, idf_weights)
diff --git a/keras/src/utils/timeseries_dataset_utils_test.py b/keras/src/utils/timeseries_dataset_utils_test.py
index 98c75a425e3c..251b81cd3589 100644
--- a/keras/src/utils/timeseries_dataset_utils_test.py
+++ b/keras/src/utils/timeseries_dataset_utils_test.py
@@ -88,7 +88,7 @@ def test_shuffle(self):
# results
for x, _ in dataset.take(1):
self.assertNotAllClose(x, first_seq)
- # Check determism with same seed
+ # Check determinism with same seed
dataset = timeseries_dataset_utils.timeseries_dataset_from_array(
data,
targets,
diff --git a/keras/src/utils/torch_utils.py b/keras/src/utils/torch_utils.py
index 11cc136f508f..ceed2425ea25 100644
--- a/keras/src/utils/torch_utils.py
+++ b/keras/src/utils/torch_utils.py
@@ -2,6 +2,7 @@
from packaging.version import parse
+from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers import Layer
from keras.src.ops import convert_to_numpy
@@ -16,6 +17,9 @@ class TorchModuleWrapper(Layer):
`torch.nn.Module` into a Keras layer, in particular by making its
parameters trackable by Keras.
+ `TorchModuleWrapper` is only compatible with the PyTorch backend and
+ cannot be used with the TensorFlow or JAX backends.
+
Args:
module: `torch.nn.Module` instance. If it's a `LazyModule`
instance, then its parameters must be initialized before
@@ -29,11 +33,12 @@ class TorchModuleWrapper(Layer):
PyTorch modules.
```python
+ import torch
import torch.nn as nn
import torch.nn.functional as F
import keras
- from keras.src.layers import TorchModuleWrapper
+ from keras.layers import TorchModuleWrapper
class Classifier(keras.Model):
def __init__(self, **kwargs):
@@ -98,18 +103,20 @@ def parameters(self, recurse=True):
return self.module.parameters(recurse=recurse)
def _track_module_parameters(self):
- from keras.src.backend.torch import Variable
-
for param in self.module.parameters():
# The Variable will reuse the raw `param`
# and simply wrap it.
- variable = Variable(
+ variable = backend.Variable(
initializer=param, trainable=param.requires_grad
)
self._track_variable(variable)
self.built = True
- def call(self, *args, **kwargs):
+ def call(self, *args, training=None, **kwargs):
+ if training is False:
+ self.eval()
+ else:
+ self.train()
return self.module(*args, **kwargs)
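
In other words, the Keras `training` argument is now forwarded to the wrapped module's train/eval mode. A minimal sketch of the intended usage (hypothetical dropout module, names assumed):

```python
# wrapped = TorchModuleWrapper(torch.nn.Dropout(p=0.5))
# wrapped(x, training=False)  # module runs in eval mode, dropout disabled
# wrapped(x)                  # training is None -> train mode, as coded above
```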
def save_own_variables(self, store):
diff --git a/keras/src/utils/torch_utils_test.py b/keras/src/utils/torch_utils_test.py
index 7e972f5b1b56..1be561d94f5e 100644
--- a/keras/src/utils/torch_utils_test.py
+++ b/keras/src/utils/torch_utils_test.py
@@ -29,9 +29,9 @@ def __init__(
self.torch_wrappers.append(TorchModuleWrapper(torch_model))
self.fc = layers.Dense(1)
- def call(self, x):
+ def call(self, x, training=None):
for wrapper in self.torch_wrappers:
- x = wrapper(x)
+ x = wrapper(x, training=training)
return self.fc(x)
def get_config(self):
@@ -49,14 +49,14 @@ def __init__(self, *args, **kwargs):
self.fc2 = torch.nn.Linear(4, 4)
self.fc3 = layers.Dense(2)
- def call(self, x):
+ def call(self, x, training=None):
return self.fc3(self.fc2(self.bn1(self.fc1(x))))
@pytest.mark.skipif(
backend.backend() != "torch", reason="Requires torch backend"
)
-class TorchUtilsTest(testing.TestCase, parameterized.TestCase):
+class TorchUtilsTest(testing.TestCase):
@parameterized.parameters(
{"use_batch_norm": False, "num_torch_layers": 1},
{"use_batch_norm": True, "num_torch_layers": 1},
@@ -82,6 +82,50 @@ def test_basic_usage(self, use_batch_norm, num_torch_layers):
model.compile(optimizer="sgd", loss="mse")
model.fit(np.random.random((3, 2)), np.random.random((3, 1)))
+ @parameterized.named_parameters(
+ (
+ "explicit_torch_wrapper",
+ Classifier,
+ {"use_batch_norm": True, "num_torch_layers": 1},
+ ),
+ ("implicit_torch_wrapper", ClassifierWithNoSpecialCasing, {}),
+ )
+ def test_training_args(self, cls, kwargs):
+ model = cls(**kwargs)
+ model(np.random.random((3, 2)), training=False) # Eager call to build
+ ref_weights = model.get_weights()
+ ref_running_mean = backend.convert_to_numpy(
+ model.torch_wrappers[0].module[-1].running_mean
+ if cls is Classifier
+ else model.bn1.module.running_mean
+ )
+
+ # Test training=False doesn't affect model weights
+ model(np.random.random((3, 2)), training=False)
+ weights = model.get_weights()
+ for w, ref_w in zip(weights, ref_weights):
+ self.assertAllClose(w, ref_w)
+
+ # Test training=None affects BN's stats
+ model.set_weights(ref_weights) # Restore previous weights
+ model(np.random.random((3, 2)))
+ running_mean = backend.convert_to_numpy(
+ model.torch_wrappers[0].module[-1].running_mean
+ if cls is Classifier
+ else model.bn1.module.running_mean
+ )
+ self.assertNotAllClose(running_mean, ref_running_mean)
+
+ # Test training=True affects BN's stats
+ model.set_weights(ref_weights) # Restore previous weights
+ model(np.random.random((3, 2)), training=True)
+ running_mean = backend.convert_to_numpy(
+ model.torch_wrappers[0].module[-1].running_mean
+ if cls is Classifier
+ else model.bn1.module.running_mean
+ )
+ self.assertNotAllClose(running_mean, ref_running_mean)
+
def test_module_autowrapping(self):
model = ClassifierWithNoSpecialCasing()
self.assertIsInstance(model.fc1, TorchModuleWrapper)
diff --git a/keras/src/utils/tracking.py b/keras/src/utils/tracking.py
index d24cfc3836a6..a2a26679937a 100644
--- a/keras/src/utils/tracking.py
+++ b/keras/src/utils/tracking.py
@@ -185,12 +185,12 @@ def __delitem__(self, index):
self.tracker.untrack(value)
def tree_flatten(self):
- # For optree
+ # For optree / dmtree
return (self, None)
@classmethod
def tree_unflatten(cls, metadata, children):
- # For optree
+ # For optree / dmtree
return cls(children)
@@ -234,20 +234,15 @@ def clear(self):
super().clear()
def tree_flatten(self):
- from keras.src.utils.module_utils import optree
-
- # For optree
- keys, values = optree.utils.unzip2(
- optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0])
- )
- return values, list(keys), keys
+ # For optree / dmtree
+ keys = sorted(list(self.keys()))
+ values = [self[k] for k in keys]
+ return values, keys, keys
@classmethod
def tree_unflatten(cls, keys, values):
- from keras.src.utils.module_utils import optree
-
- # For optree
- return cls(optree.utils.safe_zip(keys, values))
+ # For optree / dmtree
+ return cls(zip(keys, values))
@tree.register_tree_node_class
@@ -286,10 +281,10 @@ def clear(self):
super().clear()
def tree_flatten(self):
- # For optree
+ # For optree / dmtree
return (self, None)
@classmethod
def tree_unflatten(cls, metadata, children):
- # For optree
+ # For optree / dmtree
return cls(children)
diff --git a/keras/src/utils/tracking_test.py b/keras/src/utils/tracking_test.py
index dd5e9fc90037..b255e64658a0 100644
--- a/keras/src/utils/tracking_test.py
+++ b/keras/src/utils/tracking_test.py
@@ -16,8 +16,8 @@ def test_untracking_in_tracked_list(self):
),
}
)
- v1 = backend.Variable(1)
- v2 = backend.Variable(2)
+ v1 = backend.Variable(1.0)
+ v2 = backend.Variable(2.0)
lst = tracking.TrackedList([], tracker)
lst.append(v1)
lst.append(None)
@@ -67,8 +67,8 @@ def test_tuple_tracking(self):
),
}
)
- v1 = backend.Variable(1)
- v2 = backend.Variable(2)
+ v1 = backend.Variable(1.0)
+ v2 = backend.Variable(2.0)
tup = (v1, v2)
tup = tracker.track(tup)
self.assertIsInstance(tup, tuple)
@@ -86,8 +86,8 @@ def test_namedtuple_tracking(self):
),
}
)
- v1 = backend.Variable(1)
- v2 = backend.Variable(2)
+ v1 = backend.Variable(1.0)
+ v2 = backend.Variable(2.0)
nt = collections.namedtuple("NT", ["x", "y"])
tup = nt(x=v1, y=v2)
tup = tracker.track(tup)
diff --git a/keras/src/version.py b/keras/src/version.py
index 11e49a3b9267..db523fbaa13c 100644
--- a/keras/src/version.py
+++ b/keras/src/version.py
@@ -1,7 +1,7 @@
from keras.src.api_export import keras_export
# Unique source of truth for the version number.
-__version__ = "3.3.3"
+__version__ = "3.8.0"
@keras_export("keras.version")
diff --git a/keras/src/visualization/__init__.py b/keras/src/visualization/__init__.py
new file mode 100644
index 000000000000..04524f857be5
--- /dev/null
+++ b/keras/src/visualization/__init__.py
@@ -0,0 +1,2 @@
+from keras.src.visualization import draw_bounding_boxes
+from keras.src.visualization import plot_image_gallery
diff --git a/keras/src/visualization/draw_bounding_boxes.py b/keras/src/visualization/draw_bounding_boxes.py
new file mode 100644
index 000000000000..e5e93920d2e4
--- /dev/null
+++ b/keras/src/visualization/draw_bounding_boxes.py
@@ -0,0 +1,177 @@
+import numpy as np
+
+from keras.src import backend
+from keras.src import ops
+from keras.src.api_export import keras_export
+from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
+ convert_format,
+)
+
+try:
+ import cv2
+except ImportError:
+ cv2 = None
+
+
+@keras_export("keras.visualization.draw_bounding_boxes")
+def draw_bounding_boxes(
+ images,
+ bounding_boxes,
+ bounding_box_format,
+ class_mapping=None,
+ color=(128, 128, 128),
+ line_thickness=2,
+ text_thickness=1,
+ font_scale=1.0,
+ data_format=None,
+):
+ """Draws bounding boxes on images.
+
+ This function draws bounding boxes on a batch of images. It supports
+ different bounding box formats and can optionally display class labels
+ and confidences.
+
+ Args:
+ images: A batch of images as a 4D tensor or NumPy array. Shape should be
+ `(batch_size, height, width, channels)`.
+ bounding_boxes: A dictionary containing bounding box data. Should have
+ the following keys:
+ - `boxes`: A tensor or array of shape `(batch_size, num_boxes, 4)`
+ containing the bounding box coordinates in the specified format.
+ - `labels`: A tensor or array of shape `(batch_size, num_boxes)`
+ containing the class labels for each bounding box.
+ - `confidences` (Optional): A tensor or array of shape
+ `(batch_size, num_boxes)` containing the confidence scores for
+ each bounding box.
+ bounding_box_format: A string specifying the format of the bounding
+ boxes. Refer [keras-io](TODO)
+ class_mapping: A dictionary mapping class IDs (integers) to class labels
+ (strings). Used to display class labels next to the bounding boxes.
+ Defaults to None (no labels displayed).
+ color: A tuple or list representing the RGB color of the bounding boxes.
+ For example, `(255, 0, 0)` for red. Defaults to `(128, 128, 128)`.
+ line_thickness: An integer specifying the thickness of the bounding box
+ lines. Defaults to `2`.
+ text_thickness: An integer specifying the thickness of the text labels.
+ Defaults to `1`.
+ font_scale: A float specifying the scale of the font used for text
+ labels. Defaults to `1.0`.
+ data_format: A string, either `"channels_last"` or `"channels_first"`,
+ specifying the order of dimensions in the input images. Defaults to
+ the `image_data_format` value found in your Keras config file at
+ `~/.keras/keras.json`. If you never set it, then it will be
+ "channels_last".
+
+ Returns:
+ A NumPy array of the annotated images with the bounding boxes drawn.
+ The array will have the same shape as the input `images`.
+
+ Raises:
+ ValueError: If `images` is not a 4D tensor/array, if `bounding_boxes` is
+ not a dictionary, or if `bounding_boxes` does not contain `"boxes"`
+ and `"labels"` keys.
+ TypeError: If `bounding_boxes` is not a dictionary.
+ ImportError: If `cv2` (OpenCV) is not installed.
+ """
+
+ if cv2 is None:
+ raise ImportError(
+ "The `draw_bounding_boxes` function requires the `cv2` package "
+ " (OpenCV). Please install it with `pip install opencv-python`."
+ )
+
+ class_mapping = class_mapping or {}
+ text_thickness = (
+ text_thickness or line_thickness
+ ) # Default text_thickness if not provided.
+ data_format = data_format or backend.image_data_format()
+ images_shape = ops.shape(images)
+ if len(images_shape) != 4:
+ raise ValueError(
+ "`images` must be batched 4D tensor. "
+ f"Received: images.shape={images_shape}"
+ )
+ if not isinstance(bounding_boxes, dict):
+ raise TypeError(
+ "`bounding_boxes` should be a dict. "
+ f"Received: bounding_boxes={bounding_boxes} of type "
+ f"{type(bounding_boxes)}"
+ )
+ if "boxes" not in bounding_boxes or "labels" not in bounding_boxes:
+ raise ValueError(
+ "`bounding_boxes` should be a dict containing 'boxes' and "
+ f"'labels' keys. Received: bounding_boxes={bounding_boxes}"
+ )
+ if data_format == "channels_last":
+ h_axis = -3
+ w_axis = -2
+ else:
+ h_axis = -2
+ w_axis = -1
+ height = images_shape[h_axis]
+ width = images_shape[w_axis]
+ bounding_boxes = bounding_boxes.copy()
+ bounding_boxes = convert_format(
+ bounding_boxes, bounding_box_format, "xyxy", height, width
+ )
+
+ # To numpy array
+ images = ops.convert_to_numpy(images).astype("uint8")
+ boxes = ops.convert_to_numpy(bounding_boxes["boxes"])
+ labels = ops.convert_to_numpy(bounding_boxes["labels"])
+ if "confidences" in bounding_boxes:
+ confidences = ops.convert_to_numpy(bounding_boxes["confidences"])
+ else:
+ confidences = None
+
+ result = []
+ batch_size = images.shape[0]
+ for i in range(batch_size):
+ _image = images[i]
+ _box = boxes[i]
+ _class = labels[i]
+ for box_i in range(_box.shape[0]):
+ x1, y1, x2, y2 = _box[box_i].astype("int32")
+ c = _class[box_i].astype("int32")
+ if c == -1:
+ continue
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ c = int(c)
+ # Draw bounding box
+ cv2.rectangle(_image, (x1, y1), (x2, y2), color, line_thickness)
+
+ if c in class_mapping:
+ label = class_mapping[c]
+ if confidences is not None:
+ conf = confidences[i][box_i]
+ label = f"{label} | {conf:.2f}"
+
+ font_x1, font_y1 = _find_text_location(
+ x1, y1, font_scale, text_thickness
+ )
+ cv2.putText(
+ img=_image,
+ text=label,
+ org=(font_x1, font_y1),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+ fontScale=font_scale,
+ color=color,
+ thickness=text_thickness,
+ )
+ result.append(_image)
+ return np.stack(result, axis=0)
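
A minimal usage sketch (dummy data; assumes OpenCV is installed and the boxes are already `"xyxy"` pixel coordinates):

```python
import numpy as np
import keras

images = np.zeros((1, 128, 128, 3), dtype="uint8")
bounding_boxes = {
    "boxes": np.array([[[10.0, 10.0, 60.0, 80.0]]]),  # (batch, num_boxes, 4)
    "labels": np.array([[0]]),
}
annotated = keras.visualization.draw_bounding_boxes(
    images,
    bounding_boxes,
    bounding_box_format="xyxy",
    class_mapping={0: "cat"},
)
print(annotated.shape)  # (1, 128, 128, 3)
```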
+
+
+def _find_text_location(x, y, font_scale, thickness):
+ font_height = int(font_scale * 12)
+ target_y = y - 8
+ if target_y - (2 * font_height) > 0:
+ return x, y - 8
+
+ line_offset = thickness
+ static_offset = 3
+
+ return (
+ x + static_offset,
+ y + (2 * font_height) + line_offset + static_offset,
+ )
diff --git a/keras/src/visualization/draw_segmentation_masks.py b/keras/src/visualization/draw_segmentation_masks.py
new file mode 100644
index 000000000000..0fa8c6fbb7a1
--- /dev/null
+++ b/keras/src/visualization/draw_segmentation_masks.py
@@ -0,0 +1,109 @@
+import numpy as np
+
+from keras.src import backend
+from keras.src import ops
+from keras.src.api_export import keras_export
+
+
+@keras_export("keras.visualization.draw_segmentation_masks")
+def draw_segmentation_masks(
+ images,
+ segmentation_masks,
+ num_classes=None,
+ color_mapping=None,
+ alpha=0.8,
+ blend=True,
+ ignore_index=-1,
+ data_format=None,
+):
+ """Draws segmentation masks on images.
+
+ The function overlays segmentation masks on the input images.
+ The masks are blended with the images using the specified alpha value.
+
+ Args:
+ images: A batch of images as a 4D tensor or NumPy array. Shape
+ should be (batch_size, height, width, channels).
+ segmentation_masks: A batch of segmentation masks as a 3D or 4D tensor
+ or NumPy array. Shape should be (batch_size, height, width) or
+ (batch_size, height, width, 1). The values represent class indices
+ starting from 1 up to `num_classes`. Class 0 is reserved for
+ the background and will be ignored if `ignore_index` is not 0.
+ num_classes: The number of segmentation classes. If `None`, it is
+ inferred from the maximum value in `segmentation_masks`.
+ color_mapping: A dictionary mapping class indices to RGB colors.
+ If `None`, a default color palette is generated. The keys should be
+ integers starting from 1 up to `num_classes`.
+ alpha: The opacity of the segmentation masks. Must be in the range
+ `[0, 1]`.
+ blend: Whether to blend the masks with the input image using the
+ `alpha` value. If `False`, the masks are drawn directly on the
+ images without blending. Defaults to `True`.
+ ignore_index: The class index to ignore. Mask pixels with this value
+ will not be drawn. Defaults to -1.
+ data_format: Image data format, either `"channels_last"` or
+ `"channels_first"`. Defaults to the `image_data_format` value found
+ in your Keras config file at `~/.keras/keras.json`. If you never
+ set it, then it will be `"channels_last"`.
+
+ Returns:
+ A NumPy array of the images with the segmentation masks overlaid.
+
+ Raises:
+ ValueError: If the input `images` is not a 4D tensor or NumPy array.
+ TypeError: If the input `segmentation_masks` is not an integer type.
+ """
+ data_format = data_format or backend.image_data_format()
+ images_shape = ops.shape(images)
+ if len(images_shape) != 4:
+ raise ValueError(
+ "`images` must be batched 4D tensor. "
+ f"Received: images.shape={images_shape}"
+ )
+ if data_format == "channels_first":
+ images = ops.transpose(images, (0, 2, 3, 1))
+ segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1))
+ images = ops.convert_to_tensor(images, dtype="float32")
+ segmentation_masks = ops.convert_to_tensor(segmentation_masks)
+
+ if not backend.is_int_dtype(segmentation_masks.dtype):
+ dtype = backend.standardize_dtype(segmentation_masks.dtype)
+ raise TypeError(
+ "`segmentation_masks` must be in integer dtype. "
+ f"Received: segmentation_masks.dtype={dtype}"
+ )
+
+ # Infer num_classes
+ if num_classes is None:
+ num_classes = int(ops.convert_to_numpy(ops.max(segmentation_masks)))
+ if color_mapping is None:
+ colors = _generate_color_palette(num_classes)
+ else:
+ colors = [color_mapping[i] for i in range(num_classes)]
+ valid_masks = ops.not_equal(segmentation_masks, ignore_index)
+ valid_masks = ops.squeeze(valid_masks, axis=-1)
+ segmentation_masks = ops.one_hot(segmentation_masks, num_classes)
+ segmentation_masks = segmentation_masks[..., 0, :]
+ segmentation_masks = ops.convert_to_numpy(segmentation_masks)
+
+ # Replace class with color
+ masks = segmentation_masks
+ masks = np.transpose(masks, axes=(3, 0, 1, 2)).astype("bool")
+ images_to_draw = ops.convert_to_numpy(images).copy()
+ for mask, color in zip(masks, colors):
+ color = np.array(color, dtype=images_to_draw.dtype)
+ images_to_draw[mask, ...] = color[None, :]
+ images_to_draw = ops.convert_to_tensor(images_to_draw)
+ outputs = ops.cast(images_to_draw, dtype="float32")
+
+ if blend:
+ outputs = images * (1 - alpha) + outputs * alpha
+ outputs = ops.where(valid_masks[..., None], outputs, images)
+ outputs = ops.cast(outputs, dtype="uint8")
+ outputs = ops.convert_to_numpy(outputs)
+ return outputs
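
A minimal usage sketch (dummy data; class indices start at 1, with 0 left as background):

```python
import numpy as np
import keras

images = np.zeros((1, 64, 64, 3), dtype="uint8")
masks = np.zeros((1, 64, 64, 1), dtype="int32")
masks[0, 16:48, 16:48, 0] = 1  # a square of class 1

overlaid = keras.visualization.draw_segmentation_masks(
    images, masks, num_classes=2, alpha=0.5
)
print(overlaid.shape)  # (1, 64, 64, 3)
```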
+
+
+def _generate_color_palette(num_classes):
+ palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
+ return [((i * palette) % 255).tolist() for i in range(num_classes)]
diff --git a/keras/src/visualization/plot_bounding_box_gallery.py b/keras/src/visualization/plot_bounding_box_gallery.py
new file mode 100644
index 000000000000..3fe3242f718c
--- /dev/null
+++ b/keras/src/visualization/plot_bounding_box_gallery.py
@@ -0,0 +1,165 @@
+import functools
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import ops
+from keras.src.api_export import keras_export
+from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
+from keras.src.visualization.plot_image_gallery import plot_image_gallery
+
+try:
+ from matplotlib import patches # For legend patches
+except ImportError:
+ patches = None
+
+
+@keras_export("keras.visualization.plot_bounding_box_gallery")
+def plot_bounding_box_gallery(
+ images,
+ bounding_box_format,
+ y_true=None,
+ y_pred=None,
+ value_range=(0, 255),
+ true_color=(0, 188, 212),
+ pred_color=(255, 235, 59),
+ line_thickness=2,
+ font_scale=1.0,
+ text_thickness=None,
+ class_mapping=None,
+ ground_truth_mapping=None,
+ prediction_mapping=None,
+ legend=False,
+ legend_handles=None,
+ rows=None,
+ cols=None,
+ data_format=None,
+ **kwargs,
+):
+ """Plots a gallery of images with bounding boxes.
+
+ This function can display both ground truth and predicted bounding boxes on
+ a set of images. It supports various bounding box formats and can include
+ class labels and a legend.
+
+ Args:
+ images: A 4D tensor or NumPy array of images. Shape should be
+ `(batch_size, height, width, channels)`.
+ bounding_box_format: The format of the bounding boxes.
+ Refer [keras-io](TODO)
+ y_true: A dictionary containing the ground truth bounding boxes and
+ labels. Should have the same structure as the `bounding_boxes`
+ argument in `keras.visualization.draw_bounding_boxes`.
+ Defaults to `None`.
+ y_pred: A dictionary containing the predicted bounding boxes and labels.
+ Should have the same structure as `y_true`. Defaults to `None`.
+ value_range: A tuple specifying the value range of the images
+ (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`.
+ true_color: A tuple of three integers representing the RGB color for the
+ ground truth bounding boxes. Defaults to `(0, 188, 212)`.
+ pred_color: A tuple of three integers representing the RGB color for the
+ predicted bounding boxes. Defaults to `(255, 235, 59)`.
+ line_thickness: The thickness of the bounding box lines. Defaults to 2.
+ font_scale: The scale of the font used for labels. Defaults to 1.0.
+ text_thickness: The thickness of the bounding box text. Defaults to
+ `line_thickness`.
+ class_mapping: A dictionary mapping class IDs to class names. Used
+ for both ground truth and predicted boxes if `ground_truth_mapping`
+ and `prediction_mapping` are not provided. Defaults to `None`.
+ ground_truth_mapping: A dictionary mapping class IDs to class names
+ specifically for ground truth boxes. Overrides `class_mapping`
+ for ground truth. Defaults to `None`.
+ prediction_mapping: A dictionary mapping class IDs to class names
+ specifically for predicted boxes. Overrides `class_mapping` for
+ predictions. Defaults to `None`.
+ legend: A boolean indicating whether to show a legend.
+ Defaults to `False`.
+ legend_handles: A list of matplotlib `Patch` objects to use for the
+ legend. If this is provided, the `legend` argument will be ignored.
+ Defaults to `None`.
+ rows: The number of rows in the image gallery. Required if the images
+ are not batched. Defaults to `None`.
+ cols: The number of columns in the image gallery. Required if the images
+ are not batched. Defaults to `None`.
+ data_format: The image data format `"channels_last"` or
+ `"channels_first"`. Defaults to the Keras backend data format.
+ kwargs: Additional keyword arguments to be passed to
+ `keras.visualization.plot_image_gallery`.
+
+ Returns:
+ The output of `keras.visualization.plot_image_gallery`.
+
+ Raises:
+ ValueError: If `images` is not a 4D tensor/array or if both `legend`
+ and `legend_handles` are specified.
+ ImportError: If `matplotlib` is not installed.
+ """
+ if patches is None:
+ raise ImportError(
+ "The `plot_bounding_box_gallery` function requires the "
+ " `matplotlib` package. Please install it with "
+ " `pip install matplotlib`."
+ )
+
+ prediction_mapping = prediction_mapping or class_mapping
+ ground_truth_mapping = ground_truth_mapping or class_mapping
+ data_format = data_format or backend.image_data_format()
+ images_shape = ops.shape(images)
+ if len(images_shape) != 4:
+ raise ValueError(
+ "`images` must be batched 4D tensor. "
+ f"Received: images.shape={images_shape}"
+ )
+ if data_format == "channels_first": # Ensure correct data format
+ images = ops.transpose(images, (0, 2, 3, 1))
+ plotted_images = ops.convert_to_numpy(images)
+
+ draw_fn = functools.partial(
+ draw_bounding_boxes,
+ bounding_box_format=bounding_box_format,
+ line_thickness=line_thickness,
+ text_thickness=text_thickness,
+ font_scale=font_scale,
+ )
+
+ if y_true is not None:
+ plotted_images = draw_fn(
+ plotted_images,
+ y_true,
+ color=true_color,
+ class_mapping=ground_truth_mapping,
+ )
+
+ if y_pred is not None:
+ plotted_images = draw_fn(
+ plotted_images,
+ y_pred,
+ color=pred_color,
+ class_mapping=prediction_mapping,
+ )
+
+ if legend:
+ if legend_handles:
+ raise ValueError(
+ "Only pass `legend` OR `legend_handles` to "
+ "`keras.visualization.plot_bounding_box_gallery()`."
+ )
+ legend_handles = [
+ patches.Patch(
+ color=np.array(true_color) / 255.0, # Normalize color
+ label="Ground Truth",
+ ),
+ patches.Patch(
+ color=np.array(pred_color) / 255.0, # Normalize color
+ label="Prediction",
+ ),
+ ]
+
+ return plot_image_gallery(
+ plotted_images,
+ value_range=value_range,
+ legend_handles=legend_handles,
+ rows=rows,
+ cols=cols,
+ **kwargs,
+ )
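
For example, overlaying predictions on a batch (sketch only; `images` and `y_pred` as described in the docstring above, with matplotlib installed):

```python
# keras.visualization.plot_bounding_box_gallery(
#     images,
#     bounding_box_format="xyxy",
#     y_pred=y_pred,
#     class_mapping={0: "cat", 1: "dog"},
#     legend=True,
#     rows=2,
#     cols=2,
# )
```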
diff --git a/keras/src/visualization/plot_image_gallery.py b/keras/src/visualization/plot_image_gallery.py
new file mode 100644
index 000000000000..902872be5387
--- /dev/null
+++ b/keras/src/visualization/plot_image_gallery.py
@@ -0,0 +1,165 @@
+import math
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import ops
+from keras.src.api_export import keras_export
+from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
+ BaseImagePreprocessingLayer,
+)
+
+try:
+ import matplotlib.pyplot as plt
+except ImportError:
+ plt = None
+
+
+def _extract_image_batch(images, num_images, batch_size):
+ """Extracts a batch of images for plotting.
+
+ Args:
+ images: The 4D tensor or NumPy array of images.
+ num_images: The number of images to extract.
+ batch_size: The original batch size of the images.
+
+ Returns:
+ A 4D tensor or NumPy array containing the extracted images.
+
+ Raises:
+ ValueError: If `images` is not a 4D tensor/array.
+ """
+
+ if len(ops.shape(images)) != 4:
+ raise ValueError(
+ "`plot_images_gallery()` requires you to "
+ "batch your `np.array` samples together."
+ )
+ num_samples = min(num_images, batch_size)
+ sample = images[:num_samples, ...]
+
+ return sample
+
+
+@keras_export("keras.visualization.plot_image_gallery")
+def plot_image_gallery(
+ images,
+ rows=None,
+ cols=None,
+ value_range=(0, 255),
+ scale=2,
+ path=None,
+ show=None,
+ transparent=True,
+ dpi=60,
+ legend_handles=None,
+ data_format=None,
+):
+ """Displays a gallery of images.
+
+ Args:
+ images: A 4D tensor or NumPy array of images. Shape should be
+ `(batch_size, height, width, channels)`.
+ value_range: A tuple specifying the value range of the images
+ (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`.
+ rows: The number of rows in the gallery. If `None`, it's calculated
+ based on the number of images and `cols`. Defaults to `None`.
+ cols: The number of columns in the gallery. If `None`, it's calculated
+ based on the number of images and `rows`. Defaults to `None`.
+ scale: A float controlling the size of the displayed images. The images
+ are scaled by this factor. Defaults to `2`.
+ path: The path to save the generated gallery image. If `None`, the
+ image is displayed using `plt.show()`. Defaults to `None`.
+ show: Whether to display the image using `plt.show()`. If `True`, the
+ image is displayed. If `False`, the image is not displayed.
+ Ignored if `path` is not `None`. Defaults to `True` if `path`
+ is `None`, `False` otherwise.
+ transparent: A boolean, whether to save the figure with a transparent
+ background. Defaults to `True`.
+ dpi: The DPI (dots per inch) for saving the figure. Defaults to 60.
+ legend_handles: A list of matplotlib `Patch` objects to use as legend
+ handles. Defaults to `None`.
+ data_format: The image data format `"channels_last"` or
+ `"channels_first"`. Defaults to the Keras backend data format.
+
+ Raises:
+ ValueError: If both `path` and `show` are set to non-`None` values or if
+ `images` is not a 4D tensor or array.
+ ImportError: If `matplotlib` is not installed.
+ """
+ if plt is None:
+ raise ImportError(
+ "The `plot_image_gallery` function requires the `matplotlib` "
+ "package. Please install it with `pip install matplotlib`."
+ )
+
+ if path is not None and show:
+ raise ValueError(
+ "plot_gallery() expects either `path` to be set, or `show` "
+ "to be true."
+ )
+
+ show = show if show is not None else (path is None)
+ data_format = data_format or backend.image_data_format()
+
+ batch_size = ops.shape(images)[0] if len(ops.shape(images)) == 4 else 1
+
+ rows = rows or int(math.ceil(math.sqrt(batch_size)))
+ cols = cols or int(math.ceil(batch_size / rows))
+ num_images = rows * cols
+
+ images = _extract_image_batch(images, num_images, batch_size)
+ if (
+ data_format == "channels_first"
+ ): # Ensure correct data format for plotting
+ images = ops.transpose(images, (0, 2, 3, 1))
+ # Generate subplots
+ fig, axes = plt.subplots(
+ nrows=rows,
+ ncols=cols,
+ figsize=(cols * scale, rows * scale),
+ frameon=False,
+ layout="tight",
+ squeeze=True,
+ sharex="row",
+ sharey="col",
+ )
+ fig.subplots_adjust(wspace=0, hspace=0)
+
+ if isinstance(axes, np.ndarray) and len(axes.shape) == 1:
+ expand_axis = 0 if rows == 1 else -1
+ axes = np.expand_dims(axes, expand_axis)
+
+ if legend_handles is not None:
+ fig.legend(handles=legend_handles, loc="lower center")
+
+ images = BaseImagePreprocessingLayer()._transform_value_range(
+ images=images, original_range=value_range, target_range=(0, 255)
+ )
+
+ images = ops.convert_to_numpy(images)
+ if data_format == "channels_first":
+ images = images.transpose(0, 2, 3, 1)
+
+ for row in range(rows):
+ for col in range(cols):
+ index = row * cols + col
+ current_axis = (
+ axes[row, col] if isinstance(axes, np.ndarray) else axes
+ )
+ current_axis.imshow(images[index].astype("uint8"))
+ current_axis.margins(x=0, y=0)
+ current_axis.axis("off")
+
+ if path is not None:
+ plt.savefig(
+ fname=path,
+ pad_inches=0,
+ bbox_inches="tight",
+ transparent=transparent,
+ dpi=dpi,
+ )
+ plt.close()
+ elif show:
+ plt.show()
+ plt.close()
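
A minimal usage sketch (random images; requires matplotlib and opens a window via `plt.show()` when `path` is not given):

```python
import numpy as np
import keras

images = np.random.randint(0, 256, size=(9, 32, 32, 3), dtype="uint8")
keras.visualization.plot_image_gallery(images, rows=3, cols=3, value_range=(0, 255))
```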
diff --git a/keras/src/visualization/plot_segmentation_mask_gallery.py b/keras/src/visualization/plot_segmentation_mask_gallery.py
new file mode 100644
index 000000000000..1edf603ddf72
--- /dev/null
+++ b/keras/src/visualization/plot_segmentation_mask_gallery.py
@@ -0,0 +1,121 @@
+import functools
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import ops
+from keras.src.api_export import keras_export
+from keras.src.visualization.draw_segmentation_masks import (
+ draw_segmentation_masks,
+)
+from keras.src.visualization.plot_image_gallery import plot_image_gallery
+
+
+@keras_export("keras.visualization.plot_segmentation_mask_gallery")
+def plot_segmentation_mask_gallery(
+ images,
+ num_classes,
+ value_range=(0, 255),
+ y_true=None,
+ y_pred=None,
+ color_mapping=None,
+ blend=True,
+ alpha=0.8,
+ ignore_index=-1,
+ data_format=None,
+ **kwargs,
+):
+ """Plots a gallery of images with corresponding segmentation masks.
+
+ Args:
+ images: A 4D tensor or NumPy array of images. Shape should be
+ `(batch_size, height, width, channels)`.
+ num_classes: The number of segmentation classes. Class indices should
+ start from `1`. Class `0` will be treated as background and
+ ignored if `ignore_index` is not 0.
+ value_range: A tuple specifying the value range of the images
+ (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`.
+ y_true: A 3D/4D tensor or NumPy array representing the ground truth
+ segmentation masks. Shape should be `(batch_size, height, width)` or
+ `(batch_size, height, width, 1)`. Defaults to `None`.
+ y_pred: A 3D/4D tensor or NumPy array representing the predicted
+ segmentation masks. Shape should be the same as `y_true`.
+ Defaults to `None`.
+ color_mapping: A dictionary mapping class indices to RGB colors.
+ If `None`, a default color palette is used. Class indices start
+ from `1`. Defaults to `None`.
+ blend: Whether to blend the masks with the input image using the
+ `alpha` value. If `False`, the masks are drawn directly on the
+ images without blending. Defaults to `True`.
+ alpha: The opacity of the segmentation masks (a float between 0 and 1).
+ Defaults to `0.8`.
+ ignore_index: The class index to ignore when drawing masks.
+ Defaults to `-1`.
+ data_format: The image data format `"channels_last"` or
+ `"channels_first"`. Defaults to the Keras backend data format.
+ kwargs: Additional keyword arguments to be passed to
+ `keras.visualization.plot_image_gallery`.
+
+ Returns:
+ The output of `keras.visualization.plot_image_gallery`.
+
+ Raises:
+ ValueError: If `images` is not a 4D tensor/array.
+ """
+ data_format = data_format or backend.image_data_format()
+ image_shape = ops.shape(images)
+ if len(image_shape) != 4:
+ raise ValueError(
+ "`images` must be batched 4D tensor. "
+ f"Received: images.shape={image_shape}"
+ )
+ if data_format == "channels_first":
+ images = ops.transpose(images, (0, 2, 3, 1))
+
+ batch_size = image_shape[0] if len(image_shape) == 4 else 1
+
+ rows = batch_size
+ cols = 1
+
+ if y_true is not None:
+ cols += 1
+
+ if y_pred is not None:
+ cols += 1
+
+ images_np = ops.convert_to_numpy(images)
+
+ draw_masks_fn = functools.partial(
+ draw_segmentation_masks,
+ num_classes=num_classes,
+ color_mapping=color_mapping,
+ alpha=alpha,
+ ignore_index=ignore_index,
+ blend=blend,
+ )
+
+ if y_true is not None:
+ if data_format == "channels_first":
+ y_true = ops.transpose(y_true, (0, 2, 3, 1))
+ y_true = ops.cast(y_true, "int32")
+ true_masks_drawn = draw_masks_fn(images_np, y_true)
+
+ if y_pred is not None:
+ if data_format == "channels_first":
+ y_pred = ops.transpose(y_pred, (0, 2, 3, 1))
+ y_pred = ops.cast(y_pred, "int32")
+ predicted_masks_drawn = draw_masks_fn(images_np, y_pred)
+
+ images_with_masks = []
+ for i in range(batch_size):
+ images_with_masks.append(images_np[i])
+ if y_true is not None:
+ images_with_masks.append(true_masks_drawn[i])
+ if y_pred is not None:
+ images_with_masks.append(predicted_masks_drawn[i])
+
+ gallery_images = np.stack(images_with_masks, axis=0)
+
+ return plot_image_gallery(
+ gallery_images, value_range=value_range, rows=rows, cols=cols, **kwargs
+ )
diff --git a/keras/src/wrappers/__init__.py b/keras/src/wrappers/__init__.py
new file mode 100644
index 000000000000..8c55aa752f5c
--- /dev/null
+++ b/keras/src/wrappers/__init__.py
@@ -0,0 +1,5 @@
+from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
+from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
+from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer
+
+__all__ = ["SKLearnClassifier", "SKLearnRegressor", "SKLearnTransformer"]
diff --git a/keras/src/wrappers/fixes.py b/keras/src/wrappers/fixes.py
new file mode 100644
index 000000000000..e16819782526
--- /dev/null
+++ b/keras/src/wrappers/fixes.py
@@ -0,0 +1,83 @@
+try:
+ import sklearn
+except ImportError:
+ sklearn = None
+
+
+def _validate_data(estimator, *args, **kwargs):
+ """Validate the input data.
+
+ Wrapper for `sklearn.utils.validation.validate_data` or
+ `BaseEstimator._validate_data`, depending on the scikit-learn version.
+
+ TODO: remove when minimum scikit-learn version is 1.6
+ """
+ try:
+ # scikit-learn >= 1.6
+ from sklearn.utils.validation import validate_data
+
+ return validate_data(estimator, *args, **kwargs)
+ except ImportError:
+ return estimator._validate_data(*args, **kwargs)
+ except:
+ raise
+
+
+def type_of_target(y, input_name="", *, raise_unknown=False):
+ def _raise_or_return(target_type):
+ """Depending on the value of raise_unknown, either raise an error or
+ return 'unknown'.
+ """
+ if raise_unknown and target_type == "unknown":
+ input = input_name if input_name else "data"
+ raise ValueError(f"Unknown label type for {input}: {y!r}")
+ else:
+ return target_type
+
+ target_type = sklearn.utils.multiclass.type_of_target(
+ y, input_name=input_name
+ )
+ return _raise_or_return(target_type)
+
+
+def _routing_enabled():
+ """Return whether metadata routing is enabled.
+
+ Returns:
+ enabled : bool
+ Whether metadata routing is enabled. If the config is not set, it
+ defaults to False.
+
+ TODO: remove when the config key is no longer available in scikit-learn
+ """
+ return sklearn.get_config().get("enable_metadata_routing", False)
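
Routing is an opt-in scikit-learn feature, so this returns `False` unless it was explicitly enabled, e.g.:

```python
import sklearn

sklearn.set_config(enable_metadata_routing=True)
assert _routing_enabled()
```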
+
+
+def _raise_for_params(params, owner, method):
+ """Raise an error if metadata routing is not enabled and params are passed.
+
+ Parameters:
+ params : dict
+ The metadata passed to a method.
+ owner : object
+ The object to which the method belongs.
+ method : str
+ The name of the method, e.g. "fit".
+
+ Raises:
+ ValueError
+ If metadata routing is not enabled and params are passed.
+ """
+ caller = (
+ f"{owner.__class__.__name__}.{method}"
+ if method
+ else owner.__class__.__name__
+ )
+ if not _routing_enabled() and params:
+ raise ValueError(
+ f"Passing extra keyword arguments to {caller} is only supported if"
+ " enable_metadata_routing=True, which you can set using"
+ " `sklearn.set_config`. See the User Guide"
+ " for more"
+ f" details. Extra parameters passed are: {set(params)}"
+ )
diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py
new file mode 100644
index 000000000000..250b12c51274
--- /dev/null
+++ b/keras/src/wrappers/sklearn_test.py
@@ -0,0 +1,160 @@
+"""Tests using Scikit-Learn's bundled estimator_checks."""
+
+from contextlib import contextmanager
+
+import pytest
+import sklearn
+from packaging.version import parse as parse_version
+from sklearn.utils.estimator_checks import parametrize_with_checks
+
+import keras
+from keras.src.backend import floatx
+from keras.src.backend import set_floatx
+from keras.src.layers import Dense
+from keras.src.layers import Input
+from keras.src.models import Model
+from keras.src.wrappers import SKLearnClassifier
+from keras.src.wrappers import SKLearnRegressor
+from keras.src.wrappers import SKLearnTransformer
+
+
+def wrapped_parametrize_with_checks(
+ estimators,
+ *,
+ legacy=True,
+ expected_failed_checks=None,
+):
+ """Wrapped `parametrize_with_checks` handling backwards compat."""
+ sklearn_version = parse_version(
+ parse_version(sklearn.__version__).base_version
+ )
+
+ if sklearn_version >= parse_version("1.6"):
+ return parametrize_with_checks(
+ estimators,
+ legacy=legacy,
+ expected_failed_checks=expected_failed_checks,
+ )
+
+ def patched_more_tags(estimator, expected_failed_checks):
+ import copy
+
+ original_tags = copy.deepcopy(sklearn.utils._tags._safe_tags(estimator))
+
+ def patched_more_tags(self):
+ original_tags.update({"_xfail_checks": expected_failed_checks})
+ return original_tags
+
+ estimator.__class__._more_tags = patched_more_tags
+ return estimator
+
+ estimators = [
+ patched_more_tags(estimator, expected_failed_checks(estimator))
+ for estimator in estimators
+ ]
+
+ # legacy is not supported and ignored
+ return parametrize_with_checks(estimators)
+
+
+def dynamic_model(X, y, loss, layers=[10]):
+ """Creates a basic MLP classifier dynamically choosing binary/multiclass
+ classification loss and output activations.
+ """
+ n_features_in = X.shape[1]
+ inp = Input(shape=(n_features_in,))
+
+ hidden = inp
+ for layer_size in layers:
+ hidden = Dense(layer_size, activation="relu")(hidden)
+
+ n_outputs = y.shape[1] if len(y.shape) > 1 else 1
+ out = [Dense(n_outputs, activation="softmax")(hidden)]
+ model = Model(inp, out)
+ model.compile(loss=loss, optimizer="rmsprop")
+
+ return model
+
+
+@contextmanager
+def use_floatx(x):
+ """Context manager to temporarily
+ set the keras backend precision.
+ """
+ _floatx = floatx()
+ set_floatx(x)
+ try:
+ yield
+ finally:
+ set_floatx(_floatx)
+
+
+EXPECTED_FAILED_CHECKS = {
+ "SKLearnClassifier": {
+ "check_classifiers_regression_target": "not an issue in sklearn>=1.6",
+ "check_parameters_default_constructible": (
+ "not an issue in sklearn>=1.6"
+ ),
+ "check_classifiers_one_label_sample_weights": (
+ "0 sample weight is not ignored"
+ ),
+ "check_classifiers_classes": (
+ "with small test cases the estimator returns not all classes "
+ "sometimes"
+ ),
+ "check_classifier_data_not_an_array": (
+ "This test assumes reproducibility in fit."
+ ),
+ "check_supervised_y_2d": "This test assumes reproducibility in fit.",
+ "check_fit_idempotent": "This test assumes reproducibility in fit.",
+ },
+ "SKLearnRegressor": {
+ "check_parameters_default_constructible": (
+ "not an issue in sklearn>=1.6"
+ ),
+ },
+ "SKLearnTransformer": {
+ "check_parameters_default_constructible": (
+ "not an issue in sklearn>=1.6"
+ ),
+ },
+}
+
+
+@wrapped_parametrize_with_checks(
+ estimators=[
+ SKLearnClassifier(
+ model=dynamic_model,
+ model_kwargs={
+ "loss": "categorical_crossentropy",
+ "layers": [20, 20, 20],
+ },
+ fit_kwargs={"epochs": 5},
+ ),
+ SKLearnRegressor(
+ model=dynamic_model,
+ model_kwargs={"loss": "mse"},
+ ),
+ SKLearnTransformer(
+ model=dynamic_model,
+ model_kwargs={"loss": "mse"},
+ ),
+ ],
+ expected_failed_checks=lambda estimator: EXPECTED_FAILED_CHECKS[
+ type(estimator).__name__
+ ],
+)
+def test_sklearn_estimator_checks(estimator, check):
+ """Checks that can be passed with sklearn's default tolerances
+ and in a single epoch.
+ """
+ try:
+ check(estimator)
+ except Exception as exc:
+ if keras.config.backend() in ["numpy", "openvino"] and (
+ isinstance(exc, NotImplementedError)
+ or "NotImplementedError" in str(exc)
+ ):
+ pytest.xfail("Backend not implemented")
+ else:
+ raise
diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py
new file mode 100644
index 000000000000..77c3488b6ee9
--- /dev/null
+++ b/keras/src/wrappers/sklearn_wrapper.py
@@ -0,0 +1,488 @@
+import copy
+
+import numpy as np
+
+from keras.src.api_export import keras_export
+from keras.src.models.cloning import clone_model
+from keras.src.models.model import Model
+from keras.src.wrappers.fixes import _routing_enabled
+from keras.src.wrappers.fixes import _validate_data
+from keras.src.wrappers.fixes import type_of_target
+from keras.src.wrappers.utils import TargetReshaper
+from keras.src.wrappers.utils import _check_model
+from keras.src.wrappers.utils import assert_sklearn_installed
+
+try:
+ import sklearn
+ from sklearn.base import BaseEstimator
+ from sklearn.base import ClassifierMixin
+ from sklearn.base import RegressorMixin
+ from sklearn.base import TransformerMixin
+except ImportError:
+ sklearn = None
+
+ class BaseEstimator:
+ pass
+
+ class ClassifierMixin:
+ pass
+
+ class RegressorMixin:
+ pass
+
+ class TransformerMixin:
+ pass
+
+
+class SKLBase(BaseEstimator):
+ """Base class for scikit-learn wrappers.
+
+ Note that there are sources of randomness in model initialization and
+ training. Refer to [Reproducibility in Keras Models](
+ https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+ control randomness.
+
+ Args:
+ model: `Model`.
+ An instance of `Model`, or a callable returning such an object.
+ Note that if the input is a `Model`, it will be cloned using
+ `keras.models.clone_model` before being fitted, unless
+ `warm_start=True`.
+ The `Model` instance needs to be passed as already compiled.
+ If callable, it must accept at least `X` and `y` as keyword
+ arguments. Other arguments must be accepted if passed as
+ `model_kwargs` by the user.
+ warm_start: bool, defaults to `False`.
+ Whether to reuse the model weights from the previous fit. If `True`,
+ the given model won't be cloned and the weights from the previous
+ fit will be reused.
+ model_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model`, if `model` is callable.
+ fit_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model.fit`. These can also be passed
+ directly to the `fit` method of the scikit-learn wrapper. The
+ values passed directly to the `fit` method take precedence over
+ these.
+
+ Attributes:
+ model_ : `Model`
+ The fitted model.
+ history_ : `keras.callbacks.History`
+ The training history returned by `model.fit`.
+ """
+
+ def __init__(
+ self,
+ model,
+ warm_start=False,
+ model_kwargs=None,
+ fit_kwargs=None,
+ ):
+ assert_sklearn_installed(self.__class__.__name__)
+ self.model = model
+ self.warm_start = warm_start
+ self.model_kwargs = model_kwargs
+ self.fit_kwargs = fit_kwargs
+
+ def _more_tags(self):
+ return {"non_deterministic": True}
+
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.non_deterministic = True
+ return tags
+
+ def __sklearn_clone__(self):
+ """Return a deep copy of the model.
+
+ This is used by the `sklearn.base.clone` function.
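+
+ A minimal sketch of the intended behavior (assuming a `dynamic_model`
+ builder like the one shown in the `SKLearnClassifier` example):
+
+ ``` python
+ from sklearn.base import clone
+
+ est = SKLearnRegressor(model=dynamic_model)
+ est_copy = clone(est)  # a new, unfitted wrapper around the same builder
+ ```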
+ """
+ model = (
+ self.model if callable(self.model) else copy.deepcopy(self.model)
+ )
+ return type(self)(
+ model=model,
+ warm_start=self.warm_start,
+ model_kwargs=self.model_kwargs,
+ )
+
+ @property
+ def epoch_(self):
+ """The current training epoch."""
+ return getattr(self, "history_", {}).get("epoch", 0)
+
+ def set_fit_request(self, **kwargs):
+ """Set requested parameters by the fit method.
+
+ Please see [scikit-learn's metadata routing](
+ https://scikit-learn.org/stable/metadata_routing.html) for more
+ details.
+
+ Args:
+ kwargs: dict
+ Arguments should be of the form `param_name=alias`, and `alias`
+ can be one of `{True, False, None, str}`.
+
+ Returns:
+ self
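+
+ A minimal sketch, assuming metadata routing has been enabled and that a
+ meta-estimator such as `sklearn.model_selection.cross_validate` forwards
+ a user-supplied `sample_weight` to `fit` (the `dynamic_model` builder is
+ the one from the `SKLearnClassifier` example):
+
+ ``` python
+ import sklearn
+
+ sklearn.set_config(enable_metadata_routing=True)
+ est = SKLearnClassifier(model=dynamic_model).set_fit_request(
+     sample_weight=True
+ )
+ ```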
+ """
+ if not _routing_enabled():
+ raise RuntimeError(
+ "This method is only available when metadata routing is "
+ "enabled. You can enable it using "
+ "sklearn.set_config(enable_metadata_routing=True)."
+ )
+
+ self._metadata_request = sklearn.utils.metadata_routing.MetadataRequest(
+ owner=self.__class__.__name__
+ )
+ for param, alias in kwargs.items():
+ self._metadata_request.score.add_request(param=param, alias=alias)
+ return self
+
+ def _get_model(self, X, y):
+ if isinstance(self.model, Model):
+ return clone_model(self.model)
+ else:
+ args = self.model_kwargs or {}
+ return self.model(X=X, y=y, **args)
+
+ def fit(self, X, y, **kwargs):
+ """Fit the model.
+
+ Args:
+ X: array-like, shape=(n_samples, n_features)
+ The input samples.
+ y: array-like, shape=(n_samples,) or (n_samples, n_outputs)
+ The targets.
+ **kwargs: keyword arguments passed to `model.fit`
+ """
+ X, y = _validate_data(self, X, y)
+ y = self._process_target(y, reset=True)
+ model = self._get_model(X, y)
+ _check_model(model)
+
+ # Merge without mutating the user-provided dict; kwargs passed to
+ # this call take precedence over `self.fit_kwargs`.
+ fit_kwargs = {**(self.fit_kwargs or {}), **kwargs}
+ self.history_ = model.fit(X, y, **fit_kwargs)
+
+ self.model_ = model
+ return self
+
+ def predict(self, X):
+ """Predict using the model."""
+ sklearn.base.check_is_fitted(self)
+ X = _validate_data(self, X, reset=False)
+ raw_output = self.model_.predict(X)
+ return self._reverse_process_target(raw_output)
+
+ def _process_target(self, y, reset=False):
+ """Regressors are NOOP here, classifiers do OHE."""
+ # This is here to raise the right error in case of invalid target
+ type_of_target(y, raise_unknown=True)
+ if reset:
+ self._target_encoder = TargetReshaper().fit(y)
+ return self._target_encoder.transform(y)
+
+ def _reverse_process_target(self, y):
+ """Regressors are NOOP here, classifiers reverse OHE."""
+ return self._target_encoder.inverse_transform(y)
+
+
+@keras_export("keras.wrappers.SKLearnClassifier")
+class SKLearnClassifier(ClassifierMixin, SKLBase):
+ """scikit-learn compatible classifier wrapper for Keras models.
+
+ Note that there are sources of randomness in model initialization and
+ training. Refer to [Reproducibility in Keras Models](
+ https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+ control randomness.
+
+ Args:
+ model: `Model`.
+ An instance of `Model`, or a callable returning such an object.
+ Note that if the input is a `Model`, it will be cloned using
+ `keras.models.clone_model` before being fitted, unless
+ `warm_start=True`.
+ The `Model` instance needs to be passed as already compiled.
+ If callable, it must accept at least `X` and `y` as keyword
+ arguments. Other arguments must be accepted if passed as
+ `model_kwargs` by the user.
+ warm_start: bool, defaults to `False`.
+ Whether to reuse the model weights from the previous fit. If `True`,
+ the given model won't be cloned and the weights from the previous
+ fit will be reused.
+ model_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model`, if `model` is callable.
+ fit_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model.fit`. These can also be passed
+ directly to the `fit` method of the scikit-learn wrapper. The
+ values passed directly to the `fit` method take precedence over
+ these.
+
+ Attributes:
+ model_ : `Model`
+ The fitted model.
+ history_ : `keras.callbacks.History`
+ The training history returned by `model.fit`.
+ classes_ : array-like, shape=(n_classes,)
+ The class labels.
+
+ Example:
+ Here we use a function which creates a basic MLP model dynamically
+ choosing the input and output shapes. We will use this to create our
+ scikit-learn model.
+
+ ``` python
+ from keras.layers import Dense, Input
+ from keras.models import Model
+
+ def dynamic_model(X, y, loss, layers=[10]):
+ # Creates a basic MLP model dynamically choosing the input and
+ # output shapes.
+ n_features_in = X.shape[1]
+ inp = Input(shape=(n_features_in,))
+
+ hidden = inp
+ for layer_size in layers:
+ hidden = Dense(layer_size, activation="relu")(hidden)
+
+ n_outputs = y.shape[1] if len(y.shape) > 1 else 1
+ out = [Dense(n_outputs, activation="softmax")(hidden)]
+ model = Model(inp, out)
+ model.compile(loss=loss, optimizer="rmsprop")
+
+ return model
+ ```
+
+ You can then use this function to create a scikit-learn compatible model
+ and fit it on some data.
+
+ ``` python
+ from sklearn.datasets import make_classification
+ from keras.wrappers import SKLearnClassifier
+
+ X, y = make_classification(
+     n_samples=1000, n_features=10, n_classes=3, n_informative=4
+ )
+ est = SKLearnClassifier(
+ model=dynamic_model,
+ model_kwargs={
+ "loss": "categorical_crossentropy",
+ "layers": [20, 20, 20],
+ },
+ )
+
+ est.fit(X, y, epochs=5)
+ ```
+ """
+
+ def _process_target(self, y, reset=False):
+ """Classifiers do OHE."""
+ target_type = type_of_target(y, raise_unknown=True)
+ if target_type not in ["binary", "multiclass"]:
+ raise ValueError(
+ "Only binary and multiclass target types are supported."
+ f" Target type: {target_type}"
+ )
+ if reset:
+ self._target_encoder = sklearn.pipeline.make_pipeline(
+ TargetReshaper(),
+ sklearn.preprocessing.OneHotEncoder(sparse_output=False),
+ ).fit(y)
+ self.classes_ = np.unique(y)
+ if len(self.classes_) == 1:
+ raise ValueError(
+ "Classifier can't train when only one class is present."
+ )
+ return self._target_encoder.transform(y)
+
+ def _more_tags(self):
+ # required to be compatible with scikit-learn<1.6
+ return {"poor_score": True}
+
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.classifier_tags.poor_score = True
+ return tags
+
+
+@keras_export("keras.wrappers.SKLearnRegressor")
+class SKLearnRegressor(RegressorMixin, SKLBase):
+ """scikit-learn compatible regressor wrapper for Keras models.
+
+ Note that there are sources of randomness in model initialization and
+ training. Refer to [Reproducibility in Keras Models](
+ https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+ control randomness.
+
+ Args:
+ model: `Model`.
+ An instance of `Model`, or a callable returning such an object.
+ Note that if the input is a `Model`, it will be cloned using
+ `keras.models.clone_model` before being fitted, unless
+ `warm_start=True`.
+ The `Model` instance needs to be passed as already compiled.
+ If callable, it must accept at least `X` and `y` as keyword
+ arguments. Other arguments must be accepted if passed as
+ `model_kwargs` by the user.
+ warm_start: bool, defaults to `False`.
+ Whether to reuse the model weights from the previous fit. If `True`,
+ the given model won't be cloned and the weights from the previous
+ fit will be reused.
+ model_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model`, if `model` is callable.
+ fit_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model.fit`. These can also be passed
+ directly to the `fit` method of the scikit-learn wrapper. The
+ values passed directly to the `fit` method take precedence over
+ these.
+
+ Attributes:
+ model_ : `Model`
+ The fitted model.
+
+ Example:
+ Here we use a function which creates a basic MLP model dynamically
+ choosing the input and output shapes. We will use this to create our
+ scikit-learn model.
+
+ ``` python
+ from keras.layers import Dense, Input
+ from keras.models import Model
+
+ def dynamic_model(X, y, loss, layers=[10]):
+ # Creates a basic MLP model dynamically choosing the input and
+ # output shapes.
+ n_features_in = X.shape[1]
+ inp = Input(shape=(n_features_in,))
+
+ hidden = inp
+ for layer_size in layers:
+ hidden = Dense(layer_size, activation="relu")(hidden)
+
+ n_outputs = y.shape[1] if len(y.shape) > 1 else 1
+ out = [Dense(n_outputs)(hidden)]  # linear output for regression
+ model = Model(inp, out)
+ model.compile(loss=loss, optimizer="rmsprop")
+
+ return model
+ ```
+
+ You can then use this function to create a scikit-learn compatible model
+ and fit it on some data.
+
+ ``` python
+ from sklearn.datasets import make_regression
+ from keras.wrappers import SKLearnRegressor
+
+ X, y = make_regression(n_samples=1000, n_features=10)
+ est = SKLearnRegressor(
+ model=dynamic_model,
+ model_kwargs={
+ "loss": "mse",
+ "layers": [20, 20, 20],
+ },
+ )
+
+ est.fit(X, y, epochs=5)
+ ```
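+
+ An already-compiled `Model` instance can also be passed directly; a
+ minimal sketch, reusing `X, y` from above (the wrapper clones the model
+ with `keras.models.clone_model` before fitting unless `warm_start=True`):
+
+ ``` python
+ from keras.layers import Dense, Input
+ from keras.models import Model
+
+ inp = Input(shape=(10,))
+ hidden = Dense(20, activation="relu")(inp)
+ net = Model(inp, Dense(1)(hidden))
+ net.compile(loss="mse", optimizer="rmsprop")
+
+ est = SKLearnRegressor(model=net)
+ est.fit(X, y, epochs=5)
+ ```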
+ """
+
+ def _more_tags(self):
+ # required to be compatible with scikit-learn<1.6
+ return {"poor_score": True}
+
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.regressor_tags.poor_score = True
+ return tags
+
+
+@keras_export("keras.wrappers.SKLearnTransformer")
+class SKLearnTransformer(TransformerMixin, SKLBase):
+ """scikit-learn compatible transformer wrapper for Keras models.
+
+ Note that this is a scikit-learn compatible transformer, and not a
+ transformer in the deep learning sense.
+
+ Also note that there are sources of randomness in model initialization and
+ training. Refer to [Reproducibility in Keras Models](
+ https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+ control randomness.
+
+ Args:
+ model: `Model`.
+ An instance of `Model`, or a callable returning such an object.
+ Note that if the input is a `Model`, it will be cloned using
+ `keras.models.clone_model` before being fitted, unless
+ `warm_start=True`.
+ The `Model` instance needs to be passed as already compiled.
+ If callable, it must accept at least `X` and `y` as keyword
+ arguments. Other arguments must be accepted if passed as
+ `model_kwargs` by the user.
+ warm_start: bool, defaults to `False`.
+ Whether to reuse the model weights from the previous fit. If `True`,
+ the given model won't be cloned and the weights from the previous
+ fit will be reused.
+ model_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model`, if `model` is callable.
+ fit_kwargs: dict, defaults to `None`.
+ Keyword arguments passed to `model.fit`. These can also be passed
+ directly to the `fit` method of the scikit-learn wrapper. The
+ values passed directly to the `fit` method take precedence over
+ these.
+
+ Attributes:
+ model_ : `Model`
+ The fitted model.
+ history_ : `keras.callbacks.History`
+ The training history returned by `model.fit`.
+
+ Example:
+ A common use case for a scikit-learn transformer is a step that
+ produces embeddings of your data. Here we assume `my_package.my_model`
+ is a Keras model that maps the input data to embeddings, and
+ `my_package.my_data` is your dataset loader.
+
+ ``` python
+ from my_package import my_model, my_data
+ from keras.wrappers import SKLearnTransformer
+ from sklearn.frozen import FrozenEstimator # requires scikit-learn>=1.6
+ from sklearn.pipeline import make_pipeline
+ from sklearn.ensemble import HistGradientBoostingClassifier
+
+ X, y = my_data()
+
+ trs = FrozenEstimator(SKLearnTransformer(model=my_model))
+ pipe = make_pipeline(trs, HistGradientBoostingClassifier())
+ pipe.fit(X, y)
+ ```
+
+ Note that in the above example, `FrozenEstimator` prevents any further
+ training of the transformer step in the pipeline, which is useful when
+ you don't want the embedding model to be modified further.
+ """
+
+ def transform(self, X):
+ """Transform the data.
+
+ Args:
+ X: array-like, shape=(n_samples, n_features)
+ The input samples.
+
+ Returns:
+ X_transformed: array-like, shape=(n_samples, n_outputs)
+ The transformed data.
+ """
+ sklearn.base.check_is_fitted(self)
+ X = _validate_data(self, X, reset=False)
+ return self.model_.predict(X)
+
+ def _more_tags(self):
+ # required to be compatible with scikit-learn<1.6
+ return {
+ "preserves_dtype": [],
+ }
+
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.transformer_tags.preserves_dtype = []
+ return tags
diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py
new file mode 100644
index 000000000000..301c4b562912
--- /dev/null
+++ b/keras/src/wrappers/utils.py
@@ -0,0 +1,87 @@
+try:
+ import sklearn
+ from sklearn.base import BaseEstimator
+ from sklearn.base import TransformerMixin
+except ImportError:
+ sklearn = None
+
+ class BaseEstimator:
+ pass
+
+ class TransformerMixin:
+ pass
+
+
+def assert_sklearn_installed(symbol_name):
+ if sklearn is None:
+ raise ImportError(
+ f"{symbol_name} requires `scikit-learn` to be installed. "
+ "Run `pip install scikit-learn` to install it."
+ )
+
+
+def _check_model(model):
+ """Check whether the model need sto be compiled."""
+ # compile model if user gave us an un-compiled model
+ if not model.compiled or not model.loss or not model.optimizer:
+ raise RuntimeError(
+ "Given model needs to be compiled, and have a loss and an "
+ "optimizer."
+ )
+
+
+class TargetReshaper(TransformerMixin, BaseEstimator):
+ """Convert 1D targets to 2D and back.
+
+ For use in pipelines with transformers that only accept
+ 2D inputs, like OneHotEncoder and OrdinalEncoder.
+
+ Attributes:
+ ndim_ : int
+ Dimensions of y that the transformer was trained on.
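+
+ Example (an illustrative sketch):
+
+ ``` python
+ import numpy as np
+
+ y = np.array([0, 1, 1])
+ reshaper = TargetReshaper().fit(y)
+ y_2d = reshaper.transform(y)  # shape (3, 1)
+ y_1d = reshaper.inverse_transform(y_2d)  # back to shape (3,)
+ ```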
+ """
+
+ def fit(self, y):
+ """Fit the transformer to a target y.
+
+ Returns:
+ TargetReshaper
+ A reference to the current instance of TargetReshaper.
+ """
+ self.ndim_ = y.ndim
+ return self
+
+ def transform(self, y):
+ """Makes 1D y 2D.
+
+ Args:
+ y : np.ndarray
+ Target y to be transformed.
+
+ Returns:
+ np.ndarray
+ A numpy array of dimension at least 2.
+ """
+ if y.ndim == 1:
+ return y.reshape(-1, 1)
+ return y
+
+ def inverse_transform(self, y):
+ """Revert the transformation of transform.
+
+ Args:
+ y: np.ndarray
+ Transformed numpy array.
+
+ Returns:
+ np.ndarray
+ If the transformer was fit to a 1D numpy array,
+ and a 2D numpy array with a singleton second dimension
+ is passed, it will be squeezed back to 1D. Otherwise, it
+ will be left untouched.
+ """
+ sklearn.base.check_is_fitted(self)
+ xp, _ = sklearn.utils._array_api.get_namespace(y)
+ if self.ndim_ == 1 and y.ndim == 2:
+ return xp.squeeze(y, axis=1)
+ return y
diff --git a/pip_build.py b/pip_build.py
index 66e7578eee25..f0022e415cd6 100644
--- a/pip_build.py
+++ b/pip_build.py
@@ -29,22 +29,20 @@
package = "keras"
build_directory = "tmp_build_dir"
dist_directory = "dist"
-to_copy = ["setup.py", "README.md"]
+to_copy = ["pyproject.toml", "README.md"]
def export_version_string(version, is_nightly=False, rc_index=None):
"""Export Version and Package Name."""
if is_nightly:
date = datetime.datetime.now()
- version += f".dev{date.strftime('%Y%m%d%H')}"
- # Replaces `name="keras"` string in `setup.py` with `keras-nightly`
- with open("setup.py") as f:
- setup_contents = f.read()
- with open("setup.py", "w") as f:
- setup_contents = setup_contents.replace(
- 'name="keras"', 'name="keras-nightly"'
- )
- f.write(setup_contents)
+ version += f".dev{date:%Y%m%d%H}"
+ # Rename the package to "keras-nightly" in pyproject.toml
+ pyproj_pth = pathlib.Path("pyproject.toml")
+ pyproj_str = pyproj_pth.read_text().replace(
+ 'name = "keras"', 'name = "keras-nightly"'
+ )
+ pyproj_pth.write_text(pyproj_str)
elif rc_index is not None:
version += "rc" + str(rc_index)
diff --git a/pyproject.toml b/pyproject.toml
index e016bb363fba..773f53c68f18 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,22 +1,76 @@
-[tool.black]
+[build-system]
+requires = ["setuptools >=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "keras"
+authors = [
+ {name = "Keras team", email = "keras-users@googlegroups.com"},
+]
+description = "Multi-backend Keras"
+readme = "README.md"
+requires-python = ">=3.9"
+license = {text = "Apache License 2.0"}
+dynamic = ["version"]
+classifiers = [
+ "Development Status :: 4 - Beta",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3 :: Only",
+ "Operating System :: Unix",
+ "Operating System :: MacOS",
+ "Intended Audience :: Science/Research",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Software Development",
+]
+dependencies = [
+ "absl-py",
+ "numpy",
+ "rich",
+ "namex",
+ "h5py",
+ "optree",
+ "ml-dtypes",
+ "packaging",
+]
+# Also run: pip install -r requirements.txt
+
+[project.urls]
+Home = "https://keras.io/"
+Repository = "https://github.com/keras-team/keras"
+
+[tool.setuptools.dynamic]
+version = {attr = "keras.src.version.__version__"}
+
+[tool.setuptools.packages.find]
+include = ["keras", "keras.*"]
+
+[tool.ruff]
line-length = 80
-# black needs this to be a regex
-# to add more exclude expressions
-# append `| ` (e.g. `| .*_test\\.py`) to this list
-extend-exclude = """
-(
- examples/
-)
-"""
-
-[tool.isort]
-profile = "black"
-force_single_line = "True"
-known_first_party = ["keras_core", "tests"]
-default_section = "THIRDPARTY"
-line_length = 80
-extend_skip_glob=["examples/*", "guides/*"]
+[tool.ruff.lint]
+select = [
+ "E", # pycodestyle error
+ "F", # Pyflakes
+ "I", # isort
+]
+ignore = [
+ "E722", # do not use bare 'except'
+ "E741", # ambiguous variable name
+ "E731", # do not assign a `lambda` expression, use a `def`
+]
+
+[tool.ruff.lint.per-file-ignores]
+"**/__init__.py" = ["E501", "F401"] # lines too long; imported but unused
+"**/random.py" = ["F401"] # imported but unused
+"examples/*" = ["I", "E"]
+"guides/*" = ["I", "E", "F"]
+
+[tool.ruff.lint.isort]
+force-single-line = true
+known-first-party = ["keras"]
[tool.pytest.ini_options]
filterwarnings = [
@@ -43,13 +97,13 @@ exclude_lines = [
]
omit = [
"*/*_test.py",
- "keras_core/legacy/*",
+ "keras/src/legacy/*",
]
[tool.coverage.run]
branch = true
omit = [
"*/*_test.py",
- "keras_core/legacy/*",
+ "keras/src/legacy/*",
]
diff --git a/requirements-common.txt b/requirements-common.txt
index f645c7ba9409..51c682f9ef41 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -1,10 +1,9 @@
namex>=0.0.8
-black>=22
-flake8
-isort
+ruff
pytest
numpy
scipy
+scikit-learn
pandas
absl-py
requests
@@ -17,3 +16,9 @@ rich
build
optree
pytest-cov
+packaging
+# for tree_test.py
+dm_tree
+coverage!=7.6.5 # 7.6.5 breaks CI
+# for onnx_test.py
+onnxruntime
diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt
index e21c1cb1c5bc..7b1d2166f638 100644
--- a/requirements-jax-cuda.txt
+++ b/requirements-jax-cuda.txt
@@ -1,15 +1,17 @@
# Tensorflow cpu-only version (needed for testing).
-tensorflow-cpu~=2.16.1 # Pin to TF 2.16
+tensorflow-cpu~=2.18.0
+tf2onnx
# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
-torch>=2.1.0, <2.3.0
+torch>=2.1.0
torchvision>=0.16.0
+torch-xla
# Jax with cuda support.
-# TODO: 0.4.24 has an updated Cuda version breaks Jax CI.
+# TODO: Higher version breaks CI.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-jax[cuda12_pip]==0.4.23
+jax[cuda12]==0.4.28
flax
-r requirements-common.txt
diff --git a/requirements-openvino.txt b/requirements-openvino.txt
new file mode 100644
index 000000000000..d05f50f151f3
--- /dev/null
+++ b/requirements-openvino.txt
@@ -0,0 +1,5 @@
+# OpenVINO
+openvino
+
+# All testing deps.
+-r requirements.txt
diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt
index f3b946ddcfee..fed601f658f2 100644
--- a/requirements-tensorflow-cuda.txt
+++ b/requirements-tensorflow-cuda.txt
@@ -1,10 +1,12 @@
# Tensorflow with cuda support.
-tensorflow[and-cuda]~=2.16.1 # Pin to TF 2.16
+tensorflow[and-cuda]~=2.18.0
+tf2onnx
# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
-torch>=2.1.0, <2.3.0
+torch>=2.1.0
torchvision>=0.16.0
+torch-xla
# Jax cpu-only version (needed for testing).
jax[cpu]
diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt
index e0a71cc4e6a3..d165faa16280 100644
--- a/requirements-torch-cuda.txt
+++ b/requirements-torch-cuda.txt
@@ -1,10 +1,12 @@
# Tensorflow cpu-only version (needed for testing).
-tensorflow-cpu~=2.16.1 # Pin to TF 2.16
+tensorflow-cpu~=2.18.0
+tf2onnx
# Torch with cuda support.
--extra-index-url https://download.pytorch.org/whl/cu121
-torch==2.2.1+cu121
-torchvision==0.17.1+cu121
+torch==2.5.1+cu121
+torchvision==0.20.1+cu121
+torch-xla
# Jax cpu-only version (needed for testing).
jax[cpu]
diff --git a/requirements.txt b/requirements.txt
index 7c0000eed07e..c2054e28b907 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,15 @@
# Tensorflow.
-tensorflow-cpu~=2.16.1 # Pin to TF 2.16
+tensorflow-cpu~=2.18.0;sys_platform != 'darwin'
+tensorflow~=2.18.0;sys_platform == 'darwin'
+tf_keras
+tf2onnx
# Torch.
# TODO: Pin to < 2.3.0 (GitHub issue #19602)
--extra-index-url https://download.pytorch.org/whl/cpu
-torch>=2.1.0, <2.3.0
+torch>=2.1.0
torchvision>=0.16.0
+torch-xla;sys_platform != 'darwin'
# Jax.
jax[cpu]
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index df2bc8fe311b..000000000000
--- a/setup.cfg
+++ /dev/null
@@ -1,33 +0,0 @@
-[flake8]
-ignore =
- # Conflicts with black
- E203
- # defaults flake8 ignores
- E121,E123,E126,E226,E24,E704,W503,W504
- # Function name should be lowercase
- N802
- # lowercase ... imported as non lowercase
- # Useful to ignore for "import keras.backend as K"
- N812
- # do not use bare 'except'
- E722
- # too many "#"
- E266
-
-exclude =
- *_pb2.py,
- *_pb2_grpc.py,
-
-extend-exclude =
- # excluding examples/ and guides/ since they are formatted as follow-along guides
- examples,
- guides,
-
-
-#imported but unused in __init__.py, that's ok.
-per-file-ignores =
- # import not used
- **/__init__.py:F401
- **/random.py:F401
-
-max-line-length = 80
diff --git a/setup.py b/setup.py
deleted file mode 100644
index a78f07dda269..000000000000
--- a/setup.py
+++ /dev/null
@@ -1,63 +0,0 @@
-"""Setup script."""
-
-import os
-import pathlib
-
-from setuptools import find_packages
-from setuptools import setup
-
-
-def read(rel_path):
- here = os.path.abspath(os.path.dirname(__file__))
- with open(os.path.join(here, rel_path)) as fp:
- return fp.read()
-
-
-def get_version(rel_path):
- for line in read(rel_path).splitlines():
- if line.startswith("__version__"):
- delim = '"' if '"' in line else "'"
- return line.split(delim)[1]
- raise RuntimeError("Unable to find version string.")
-
-
-HERE = pathlib.Path(__file__).parent
-README = (HERE / "README.md").read_text()
-VERSION = get_version("keras/src/version.py")
-
-setup(
- name="keras",
- description="Multi-backend Keras.",
- long_description_content_type="text/markdown",
- long_description=README,
- version=VERSION,
- url="https://github.com/keras-team/keras",
- author="Keras team",
- author_email="keras-users@googlegroups.com",
- license="Apache License 2.0",
- install_requires=[
- "absl-py",
- "numpy",
- "rich",
- "namex",
- "h5py",
- "optree",
- "ml-dtypes",
- ],
- # Supported Python versions
- python_requires=">=3.9",
- classifiers=[
- "Development Status :: 4 - Beta",
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: 3 :: Only",
- "Operating System :: Unix",
- "Operating System :: MacOS",
- "Intended Audience :: Science/Research",
- "Topic :: Scientific/Engineering",
- "Topic :: Software Development",
- ],
- packages=find_packages(exclude=("*_test.py",)),
-)
diff --git a/shell/format.sh b/shell/format.sh
index f2992e44f895..4e1d191dbda2 100755
--- a/shell/format.sh
+++ b/shell/format.sh
@@ -1,11 +1,9 @@
#!/bin/bash
-set -Eeuo pipefail
+set -Euo pipefail
base_dir=$(dirname $(dirname $0))
-isort --sp "${base_dir}/pyproject.toml" .
+ruff check --config "${base_dir}/pyproject.toml" --fix .
-black --config "${base_dir}/pyproject.toml" .
-
-flake8 --config "${base_dir}/setup.cfg" .
+ruff format --config "${base_dir}/pyproject.toml" .
diff --git a/shell/lint.sh b/shell/lint.sh
index 8a10a2073562..37e9276f2257 100755
--- a/shell/lint.sh
+++ b/shell/lint.sh
@@ -1,11 +1,12 @@
#!/bin/bash
-set -Eeuo pipefail
+set -Euo pipefail
base_dir=$(dirname $(dirname $0))
-isort --sp "${base_dir}/pyproject.toml" --check .
+ruff check --config "${base_dir}/pyproject.toml" .
+exitcode=$?
-black --config "${base_dir}/pyproject.toml" --check .
-
-flake8 --config "${base_dir}/setup.cfg" .
+ruff format --check --config "${base_dir}/pyproject.toml" .
+exitcode=$(($exitcode + $?))
+exit $exitcode