diff --git a/keras/src/datasets/cifar10.py b/keras/src/datasets/cifar10.py index f5d8a617866..4848e0409f1 100644 --- a/keras/src/datasets/cifar10.py +++ b/keras/src/datasets/cifar10.py @@ -60,12 +60,12 @@ def load_data(): assert y_test.shape == (10000, 1) ``` """ - dirname = "cifar-10-batches-py" + dirname = "cifar-10-batches-py-target" origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" path = get_file( fname=dirname, origin=origin, - untar=True, + extract=True, file_hash=( # noqa: E501 "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" ), @@ -76,6 +76,8 @@ def load_data(): x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8") y_train = np.empty((num_train_samples,), dtype="uint8") + # batches are within an inner folder + path = os.path.join(path, "cifar-10-batches-py") for i in range(1, 6): fpath = os.path.join(path, "data_batch_" + str(i)) ( diff --git a/keras/src/datasets/cifar100.py b/keras/src/datasets/cifar100.py index 5d58b5b1878..e27421a6cf0 100644 --- a/keras/src/datasets/cifar100.py +++ b/keras/src/datasets/cifar100.py @@ -58,17 +58,18 @@ def load_data(label_mode="fine"): f"Received: label_mode={label_mode}." ) - dirname = "cifar-100-python" + dirname = "cifar-100-python-target" origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" path = get_file( fname=dirname, origin=origin, - untar=True, + extract=True, file_hash=( # noqa: E501 "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" ), ) + path = os.path.join(path, "cifar-100-python") fpath = os.path.join(path, "train") x_train, y_train = load_batch(fpath, label_key=label_mode + "_labels")